|
291 | 291 | " return value\n", |
292 | 292 | " return MetricResult(result=value)\n", |
293 | 293 | "\n", |
| 294 | + "@patch\n", |
| 295 | + "def __json__(self: MetricResult):\n", |
| 296 | + " \"\"\"Return data for JSON serialization.\n", |
| 297 | + " \n", |
| 298 | + " This method is used by json.dumps and other JSON serializers \n", |
| 299 | + " to convert MetricResult to a JSON-compatible format.\n", |
| 300 | + " \"\"\"\n", |
| 301 | + " return {\n", |
| 302 | + " \"result\": self._result,\n", |
| 303 | + " \"reason\": self.reason,\n", |
| 304 | + " }\n", |
| 305 | + "\n", |
294 | 306 | "# Add Pydantic compatibility methods\n", |
295 | 307 | "@patch(cls_method=True)\n", |
296 | 308 | "def __get_pydantic_core_schema__(\n", |
|
299 | 311 | " _handler: GetCoreSchemaHandler\n", |
300 | 312 | ") -> core_schema.CoreSchema:\n", |
301 | 313 | " \"\"\"Generate a Pydantic core schema for MetricResult.\"\"\"\n", |
302 | | - " return core_schema.with_info_plain_validator_function(cls.validate)\n", |
| 314 | + " # Create a schema that handles both validation and serialization\n", |
| 315 | + " return core_schema.union_schema([\n", |
| 316 | + " # First schema: handles validation of MetricResult instances\n", |
| 317 | + " core_schema.is_instance_schema(MetricResult),\n", |
| 318 | + " \n", |
| 319 | + " # Second schema: handles validation of other values and conversion to MetricResult\n", |
| 320 | + " core_schema.chain_schema([\n", |
| 321 | + " core_schema.any_schema(),\n", |
| 322 | + " core_schema.no_info_plain_validator_function(\n", |
| 323 | + " lambda value: MetricResult(result=value) if not isinstance(value, MetricResult) else value\n", |
| 324 | + " ),\n", |
| 325 | + " ]),\n", |
| 326 | + " ], serialization=core_schema.plain_serializer_function_ser_schema(\n", |
| 327 | + " # This function handles serialization\n", |
| 328 | + " lambda instance: instance.__json__()\n", |
| 329 | + " ))\n", |
303 | 330 | "\n", |
304 | 331 | "\n", |
305 | 332 | "@patch\n", |
|
369 | 396 | { |
370 | 397 | "cell_type": "code", |
371 | 398 | "execution_count": null, |
372 | | - "id": "9d32b10f", |
| 399 | + "id": "27f9bc1c", |
373 | 400 | "metadata": {}, |
374 | 401 | "outputs": [ |
375 | 402 | { |
376 | 403 | "data": { |
377 | 404 | "text/plain": [ |
378 | | - "'test'" |
| 405 | + "'{\"response\":\"test\",\"grade\":{\"result\":1,\"reason\":\"test\"},\"faithfulness\":{\"result\":1,\"reason\":\"test\"}}'" |
379 | 406 | ] |
380 | 407 | }, |
381 | 408 | "execution_count": null, |
|
384 | 411 | } |
385 | 412 | ], |
386 | 413 | "source": [ |
387 | | - "m.model_dump()[\"faithfulness\"].reason" |
| 414 | + "mt = TestModel(response=\"test\", grade=MetricResult(result=1, reason=\"test\"), faithfulness=MetricResult(result=1, reason=\"test\"))\n", |
| 415 | + "\n", |
| 416 | + "mt.model_dump_json()" |
388 | 417 | ] |
389 | | - }, |
390 | | - { |
391 | | - "cell_type": "code", |
392 | | - "execution_count": null, |
393 | | - "id": "bde70d56", |
394 | | - "metadata": {}, |
395 | | - "outputs": [], |
396 | | - "source": [] |
397 | 418 | } |
398 | 419 | ], |
399 | 420 | "metadata": { |
|
0 commit comments