-
Notifications
You must be signed in to change notification settings - Fork 0
Feat(#27): add input validation #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,12 @@ import type { | |
| import { | ||
| type EmitContext, | ||
| emitFile, | ||
| getFormat, | ||
| getMaxLength, | ||
| getMaxValue, | ||
| getMinLength, | ||
| getMinValue, | ||
| getPattern, | ||
| resolvePath, | ||
| walkPropertiesInherited, | ||
| } from "@typespec/compiler"; | ||
|
|
@@ -45,6 +51,231 @@ function extractDefaultValue( | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Determines if a scalar type is an integer type. | ||
| */ | ||
| function isIntegerType(type: Scalar): boolean { | ||
| const integerTypes = [ | ||
| "integer", | ||
| "int64", | ||
| "int32", | ||
| "int16", | ||
| "int8", | ||
| "uint64", | ||
| "uint32", | ||
| "uint16", | ||
| "uint8", | ||
| "safeint", | ||
| ]; | ||
|
|
||
| // Check the type itself first | ||
| if (integerTypes.includes(type.name)) { | ||
| return true; | ||
| } | ||
|
|
||
| // Walk up the base scalar chain | ||
| let baseType = type.baseScalar; | ||
| while (baseType) { | ||
| if (integerTypes.includes(baseType.name)) { | ||
| return true; | ||
| } | ||
| baseType = baseType.baseScalar; | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| /** | ||
| * Determines if a scalar type is a float type. | ||
| */ | ||
| function isFloatType(type: Scalar): boolean { | ||
| const floatTypes = ["float", "float32", "float64", "decimal", "decimal128"]; | ||
|
|
||
| // Check the type itself first | ||
| if (floatTypes.includes(type.name)) { | ||
| return true; | ||
| } | ||
|
|
||
| // Walk up the base scalar chain | ||
| let baseType = type.baseScalar; | ||
| while (baseType) { | ||
| if (floatTypes.includes(baseType.name)) { | ||
| return true; | ||
| } | ||
| baseType = baseType.baseScalar; | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| /** | ||
| * Determines if a scalar type is a date/time type. | ||
| */ | ||
| function isDateTimeType(type: Scalar): boolean { | ||
| const dateTimeTypes = [ | ||
| "utcDateTime", | ||
| "offsetDateTime", | ||
| "plainDate", | ||
| "plainTime", | ||
| ]; | ||
|
|
||
| // Check the type itself first | ||
| if (dateTimeTypes.includes(type.name)) { | ||
| return true; | ||
| } | ||
|
|
||
| // Walk up the base scalar chain | ||
| let baseType = type.baseScalar; | ||
| while (baseType) { | ||
| if (dateTimeTypes.includes(baseType.name)) { | ||
| return true; | ||
| } | ||
| baseType = baseType.baseScalar; | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| /** | ||
| * Gets the base scalar name for a type (for datetime type identification). | ||
| */ | ||
| function getBaseScalarName(type: Scalar): string { | ||
| const dateTimeTypes = [ | ||
| "utcDateTime", | ||
| "offsetDateTime", | ||
| "plainDate", | ||
| "plainTime", | ||
| ]; | ||
|
|
||
| // Check the type itself first | ||
| if (dateTimeTypes.includes(type.name)) { | ||
| return type.name; | ||
| } | ||
|
|
||
| // Walk up the base scalar chain to find a datetime type | ||
| let baseType = type.baseScalar; | ||
| while (baseType) { | ||
| if (dateTimeTypes.includes(baseType.name)) { | ||
| return baseType.name; | ||
| } | ||
| baseType = baseType.baseScalar; | ||
| } | ||
|
|
||
| return type.name; | ||
| } | ||
|
|
||
| interface ValidationConstraints { | ||
| minLength?: number; | ||
| maxLength?: number; | ||
| minValue?: number; | ||
| maxValue?: number; | ||
| pattern?: string; | ||
| format?: string; | ||
| isInteger?: boolean; | ||
| isFloat?: boolean; | ||
| isDateTime?: boolean; | ||
| dateTimeType?: string; | ||
| enumValues?: string[]; | ||
| } | ||
|
|
||
| /** | ||
| * Builds a validation function for ElectroDB based on constraints. | ||
| */ | ||
| function buildValidationFunction( | ||
| constraints: ValidationConstraints, | ||
| ): ((value: unknown) => void) | undefined { | ||
| const checks: string[] = []; | ||
|
|
||
| // String length validation | ||
| if (constraints.minLength !== undefined) { | ||
| checks.push( | ||
| `if (typeof value === "string" && value.length < ${constraints.minLength}) throw new Error("Value must be at least ${constraints.minLength} characters")`, | ||
| ); | ||
| } | ||
| if (constraints.maxLength !== undefined) { | ||
| checks.push( | ||
| `if (typeof value === "string" && value.length > ${constraints.maxLength}) throw new Error("Value must be at most ${constraints.maxLength} characters")`, | ||
| ); | ||
| } | ||
|
|
||
| // Numeric validation | ||
| if (constraints.minValue !== undefined) { | ||
| checks.push( | ||
| `if (typeof value === "number" && value < ${constraints.minValue}) throw new Error("Value must be at least ${constraints.minValue}")`, | ||
| ); | ||
| } | ||
| if (constraints.maxValue !== undefined) { | ||
| checks.push( | ||
| `if (typeof value === "number" && value > ${constraints.maxValue}) throw new Error("Value must be at most ${constraints.maxValue}")`, | ||
| ); | ||
| } | ||
|
|
||
| // Integer validation | ||
| if (constraints.isInteger) { | ||
| checks.push( | ||
| `if (typeof value === "number" && !Number.isInteger(value)) throw new Error("Value must be an integer")`, | ||
| ); | ||
| } | ||
|
|
||
| // Float validation (ensure it's a finite number) | ||
| if (constraints.isFloat) { | ||
| checks.push( | ||
| `if (typeof value === "number" && !Number.isFinite(value)) throw new Error("Value must be a finite number")`, | ||
| ); | ||
| } | ||
|
|
||
| // Pattern validation | ||
| if (constraints.pattern) { | ||
| const escapedPattern = constraints.pattern.replace(/\\/g, "\\\\"); | ||
| checks.push( | ||
| `if (typeof value === "string" && !new RegExp("${escapedPattern}").test(value)) throw new Error("Value must match pattern ${escapedPattern}")`, | ||
| ); | ||
|
Comment on lines
+228
to
+232
|
||
| } | ||
|
|
||
| // DateTime validation | ||
| if (constraints.isDateTime && constraints.dateTimeType) { | ||
| switch (constraints.dateTimeType) { | ||
| case "utcDateTime": | ||
| checks.push( | ||
| `if (typeof value === "string") { const d = new Date(value); if (isNaN(d.getTime())) throw new Error("Value must be a valid UTC date-time string"); }`, | ||
| ); | ||
| break; | ||
| case "offsetDateTime": | ||
| checks.push( | ||
| `if (typeof value === "string") { const d = new Date(value); if (isNaN(d.getTime())) throw new Error("Value must be a valid offset date-time string"); }`, | ||
| ); | ||
| break; | ||
| case "plainDate": | ||
| checks.push( | ||
| `if (typeof value === "string" && !/^\\d{4}-\\d{2}-\\d{2}$/.test(value)) throw new Error("Value must be a valid date (YYYY-MM-DD)")`, | ||
| ); | ||
| break; | ||
| case "plainTime": | ||
| checks.push( | ||
| `if (typeof value === "string" && !/^\\d{2}:\\d{2}(:\\d{2})?(\\.\\d+)?$/.test(value)) throw new Error("Value must be a valid time (HH:MM:SS)")`, | ||
| ); | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| // Enum validation | ||
| if (constraints.enumValues && constraints.enumValues.length > 0) { | ||
| const allowedValues = JSON.stringify(constraints.enumValues); | ||
| checks.push( | ||
| `if (!${allowedValues}.includes(value)) throw new Error("Value must be one of: ${constraints.enumValues.join(", ")}")`, | ||
| ); | ||
|
Comment on lines
+262
to
+266
|
||
| } | ||
|
|
||
| if (checks.length === 0) { | ||
| return undefined; | ||
| } | ||
|
|
||
| // Create the validation function as a string to be serialized | ||
| const functionBody = checks.join("; "); | ||
| // biome-ignore lint/security/noGlobalEval: This is safe since we control the input | ||
| return eval(`(value) => { ${functionBody} }`); | ||
| } | ||
|
|
||
| function emitIntrinsincScalar(type: Scalar) { | ||
| switch (type.name) { | ||
| case "boolean": | ||
|
|
@@ -269,6 +500,93 @@ function emitModelProperty(prop: ModelProperty): Attribute { | |
| const getLabel = (ctx: EmitContext, prop: ModelProperty) => | ||
| ctx.program.stateMap(StateKeys.label).get(prop); | ||
|
|
||
| /** | ||
| * Extracts validation constraints from a ModelProperty and its type. | ||
| */ | ||
| function getValidationConstraints( | ||
| ctx: EmitContext, | ||
| prop: ModelProperty, | ||
| ): ValidationConstraints { | ||
| const constraints: ValidationConstraints = {}; | ||
| const program = ctx.program; | ||
|
|
||
| // Get constraints from the property itself | ||
| const propMinLength = getMinLength(program, prop); | ||
| const propMaxLength = getMaxLength(program, prop); | ||
| const propMinValue = getMinValue(program, prop); | ||
| const propMaxValue = getMaxValue(program, prop); | ||
| const propPattern = getPattern(program, prop); | ||
| const propFormat = getFormat(program, prop); | ||
|
|
||
| if (propMinLength !== undefined) constraints.minLength = propMinLength; | ||
| if (propMaxLength !== undefined) constraints.maxLength = propMaxLength; | ||
| if (propMinValue !== undefined) constraints.minValue = propMinValue; | ||
| if (propMaxValue !== undefined) constraints.maxValue = propMaxValue; | ||
| if (propPattern !== undefined) constraints.pattern = propPattern; | ||
| if (propFormat !== undefined) constraints.format = propFormat; | ||
|
|
||
| // Get constraints from the type (Scalar types may have constraints applied to them) | ||
| if (prop.type.kind === "Scalar") { | ||
| let scalarType: Scalar | undefined = prop.type; | ||
|
|
||
| // Walk up the scalar hierarchy to collect constraints | ||
| while (scalarType) { | ||
| const typeMinLength = getMinLength(program, scalarType); | ||
| const typeMaxLength = getMaxLength(program, scalarType); | ||
| const typeMinValue = getMinValue(program, scalarType); | ||
| const typeMaxValue = getMaxValue(program, scalarType); | ||
| const typePattern = getPattern(program, scalarType); | ||
| const typeFormat = getFormat(program, scalarType); | ||
|
|
||
| // Only set if not already set (property constraints take precedence) | ||
| if (typeMinLength !== undefined && constraints.minLength === undefined) | ||
| constraints.minLength = typeMinLength; | ||
| if (typeMaxLength !== undefined && constraints.maxLength === undefined) | ||
| constraints.maxLength = typeMaxLength; | ||
| if (typeMinValue !== undefined && constraints.minValue === undefined) | ||
| constraints.minValue = typeMinValue; | ||
| if (typeMaxValue !== undefined && constraints.maxValue === undefined) | ||
| constraints.maxValue = typeMaxValue; | ||
| if (typePattern !== undefined && constraints.pattern === undefined) | ||
| constraints.pattern = typePattern; | ||
| if (typeFormat !== undefined && constraints.format === undefined) | ||
| constraints.format = typeFormat; | ||
|
|
||
| scalarType = scalarType.baseScalar; | ||
| } | ||
|
|
||
| // Check if the base type requires integer or float validation | ||
| if (isIntegerType(prop.type)) { | ||
| constraints.isInteger = true; | ||
| } else if (isFloatType(prop.type)) { | ||
| constraints.isFloat = true; | ||
| } | ||
|
|
||
| // Check for datetime types | ||
| if (isDateTimeType(prop.type)) { | ||
| constraints.isDateTime = true; | ||
| constraints.dateTimeType = getBaseScalarName(prop.type); | ||
| } | ||
| } | ||
|
|
||
| // Check for enum types | ||
| if (prop.type.kind === "Enum") { | ||
| constraints.enumValues = Array.from(prop.type.members).map( | ||
| ([key, member]) => `${member.value ?? key}`, | ||
| ); | ||
| } | ||
|
|
||
| // Check for literal unions (e.g., "home" | "work" | "other") | ||
| if (prop.type.kind === "Union") { | ||
| const literals = isLiteralUnion(prop.type); | ||
| if (literals) { | ||
| constraints.enumValues = literals; | ||
| } | ||
| } | ||
|
|
||
| return constraints; | ||
| } | ||
|
|
||
| function emitAttribute(ctx: EmitContext, prop: ModelProperty): Attribute { | ||
| const type = emitType(prop.type); | ||
|
|
||
|
|
@@ -318,6 +636,14 @@ function emitAttribute(ctx: EmitContext, prop: ModelProperty): Attribute { | |
| } | ||
| } | ||
|
|
||
| // Add validation if constraints are present | ||
| const constraints = getValidationConstraints(ctx, prop); | ||
| const validateFn = buildValidationFunction(constraints); | ||
| if (validateFn) { | ||
| // @ts-expect-error - validate is a valid ElectroDB attribute property | ||
| attr.validate = validateFn; | ||
| } | ||
|
|
||
| return attr; | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The getBaseScalarName function duplicates the dateTimeTypes array definition from isDateTimeType. Consider extracting this array to a module-level constant to avoid duplication and ensure consistency if the list needs to be updated in the future.