|
1 | | -from typing import Any, Type |
| 1 | +from typing import Any, Dict, Type |
2 | 2 |
|
3 | 3 | from graphql import GraphQLError, ValidationRule |
4 | | -from graphql.language.ast import DocumentNode, FieldNode, OperationDefinitionNode |
| 4 | +from graphql.language.ast import ( |
| 5 | + DocumentNode, |
| 6 | + FieldNode, |
| 7 | + OperationDefinitionNode, |
| 8 | + VariableDefinitionNode, |
| 9 | +) |
5 | 10 | from graphql.validation import ValidationContext |
6 | 11 |
|
7 | 12 |
|
| 13 | +class MissingVariablesError(Exception): |
| 14 | + """ |
| 15 | + Custom error class to represent errors where required variables defined in the query does |
| 16 | + not have a matching definition in the variables part of the request. Normally when this |
| 17 | + scenario occurs it would raise a GraphQLError type but that would cause a uncaught |
| 18 | + exception for some reason. The aim of this is to surface the error in the response clearly |
| 19 | + and to prevent internal server errors when it occurs. |
| 20 | + """ |
| 21 | + |
| 22 | + pass |
| 23 | + |
| 24 | + |
| 25 | +def create_required_variables_rule(variables: Dict) -> Type[ValidationRule]: |
| 26 | + class RequiredVariablesValidationRule(ValidationRule): |
| 27 | + def __init__(self, context: ValidationContext) -> None: |
| 28 | + super().__init__(context) |
| 29 | + self.variables = variables |
| 30 | + |
| 31 | + def enter_operation_definition( |
| 32 | + self, node: OperationDefinitionNode, *_args: Any |
| 33 | + ) -> None: |
| 34 | + # Get variable definitions |
| 35 | + variable_definitions = node.variable_definitions or [] |
| 36 | + |
| 37 | + # Extract variables marked as Non Null |
| 38 | + required_variables = [ |
| 39 | + var_def.variable.name.value |
| 40 | + for var_def in variable_definitions |
| 41 | + if isinstance(var_def, VariableDefinitionNode) |
| 42 | + and var_def.type.kind == "non_null_type" |
| 43 | + ] |
| 44 | + |
| 45 | + # Check if these required variables are provided |
| 46 | + missing_variables = [ |
| 47 | + var for var in required_variables if var not in self.variables |
| 48 | + ] |
| 49 | + if missing_variables: |
| 50 | + raise MissingVariablesError( |
| 51 | + f"Missing required variables: {', '.join(missing_variables)}", |
| 52 | + ) |
| 53 | + |
| 54 | + return RequiredVariablesValidationRule |
| 55 | + |
| 56 | + |
8 | 57 | def create_max_depth_rule(max_depth: int) -> Type[ValidationRule]: |
9 | 58 | class MaxDepthRule(ValidationRule): |
10 | 59 | def __init__(self, context: ValidationContext) -> None: |
|
0 commit comments