Description
I am using BoundaryML's baml-py package, which generates Pydantic classes at abzu.baml_client.types that look like the schema below:
```python
import typing
from enum import Enum

from pydantic import BaseModel

class Exchange(str, Enum):
    NYSE = "NYSE"
    NASDAQ = "NASDAQ"
    AMEX = "AMEX"
    ARCA = "ARCA"
    TSXV = "TSXV"
    NYSEMKT = "NYSEMKT"
    LSE = "LSE"
    TSE = "TSE"
    HKEX = "HKEX"
    SGX = "SGX"
    BSE = "BSE"
    ASX = "ASX"
    SSE = "SSE"
    MSE = "MSE"
    CSE = "CSE"
    BMV = "BMV"
    OSE = "OSE"
    BME = "BME"
    SWX = "SWX"

class Ticker(BaseModel):
    id: int
    uuid: typing.Optional[str] = None
    symbol: str
    exchange: typing.Optional[Exchange] = None

class Company(BaseModel):
    id: int
    uuid: typing.Optional[str] = None
    name: str
    cik: typing.Optional[str] = None
    ticker: typing.Optional["Ticker"] = None
    description: str
    website_url: typing.Optional[str] = None
    headquarters_location: typing.Optional[str] = None
    jurisdiction: typing.Optional[str] = None
    revenue_usd: typing.Optional[int] = None
    employees: typing.Optional[int] = None
    founded_year: typing.Optional[int] = None
    ceo: typing.Optional[str] = None
    linkedin_url: typing.Optional[str] = None
    source_ids: typing.Optional[typing.List[int]] = None
    source_uuids: typing.Optional[typing.List[str]] = None
```

I get the following error when I try to generate a PySpark schema. From the docs, I gather that sparkdantic can't convert the nested `class Ticker(BaseModel)` into a `T.StructType()` the way I'd like. Is this something I could fix myself, or alter the library to support?
```python
from abzu.baml_client.types import Company, CompanyList
from sparkdantic import SparkModel

class SparkCompany(Company, SparkModel):
    pass

SparkCompany.model_spark_schema()
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:238, in create_json_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    237 metadata = _get_metadata(info)
--> 238 spark_type = _from_python_type(
    239     field_type, metadata, safe_casting, by_alias, mode, exclude_fields
    240 )
    241 except Exception as raised_error:

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:413, in _from_python_type(type_, metadata, safe_casting, by_alias, mode, exclude_fields)
    411 py_type = args[0]
--> 413 if issubclass(py_type, Enum):
    414     py_type = _get_enum_mixin_type(py_type)

TypeError: issubclass() arg 1 must be a class

The above exception was the direct cause of the following exception:

TypeConversionError                       Traceback (most recent call last)
Cell In[4], line 1
----> 1 SparkCompany.model_spark_schema()

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:121, in SparkModel.model_spark_schema(cls, safe_casting, by_alias, mode, exclude_fields)
    100 @classmethod
    101 def model_spark_schema(
    102     cls,
    (...)
    106     exclude_fields: bool = False,
    107 ) -> 'StructType':
    108     """Generates a PySpark schema from the model fields. This operates similarly to
    109     `pydantic.BaseModel.model_json_schema()`.
    110
    (...)
    119         pyspark.sql.types.StructType: The generated PySpark schema.
    120     """
--> 121     return create_spark_schema(cls, safe_casting, by_alias, mode, exclude_fields)

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:306, in create_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    292 """Generates a PySpark schema from the model fields.
    293
    294 Args:
    (...)
    303     pyspark.sql.types.StructType: The generated PySpark schema.
    304 """
    305 utils.require_pyspark()
--> 306 json_schema = create_json_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    307 return StructType.fromJson(json_schema)

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:242, in create_json_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    238 spark_type = _from_python_type(
    239     field_type, metadata, safe_casting, by_alias, mode, exclude_fields
    240 )
    241 except Exception as raised_error:
--> 242 raise TypeConversionError(
    243     f'Error converting field `{name}` to PySpark type'
    244 ) from raised_error
    246 nullable = _is_optional(annotation_or_return_type)
    247 struct_field: Dict[str, Any] = {
    248     'name': name,
    249     'type': spark_type,
    250     'nullable': nullable,
    251     'metadata': comment,
    252 }

TypeConversionError: Error converting field `ticker` to PySpark type
```

It looks like the Exchange enum might work, although it is an `Exchange(str, Enum)` rather than the `Exchange(str, StrEnum)` the docs mention. I don't understand what about `class Ticker(BaseModel)` makes it unconvertible, but I'm going to dig into the code and see whether you support `pyspark.sql.types.StructType`.
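For what it's worth, the `TypeError` at the bottom of the stack can be reproduced with the standard library alone. My guess (an assumption on my part, not verified against sparkdantic's internals) is that the string forward reference in `typing.Optional["Ticker"]` reaches `issubclass()` as a `ForwardRef` rather than a class, while the `str`-mixin enum itself would be fine:

```python
import typing
from enum import Enum

class Exchange(str, Enum):   # stand-in for the generated enum
    NYSE = "NYSE"

# The str mixin means each member IS a str, so issubclass() is happy here:
assert issubclass(Exchange, str)

# But a string forward reference, like the one in Company.ticker's
# annotation, is not a class until it has been resolved:
ann = typing.Optional["Ticker"]
inner = typing.get_args(ann)[0]       # ForwardRef('Ticker')

try:
    issubclass(inner, Enum)
except TypeError as err:
    print(err)                        # issubclass() arg 1 must be a class
```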
Any help would be appreciated. I'm building an entity resolution pipeline that uses PySpark to group records into blocks for comparison and BAML to generate Pydantic classes, which are then used via the BAML client to do entity matching with an LLM. I can't directly modify my Pydantic classes, but I'll happily upgrade sparkdantic if you can point me in the right direction :)
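In case it's useful while digging: here is a stdlib-only sketch of one possible direction (the classes below are stand-ins for the generated models, and I don't know whether this matches how sparkdantic actually walks annotations). `typing.get_type_hints()` resolves the forward reference that the raw annotation still carries, after which `issubclass()` checks work:

```python
import typing
from enum import Enum

class Exchange(str, Enum):
    NYSE = "NYSE"

class Ticker:
    symbol: str = ""

class Company:
    # same shape as the generated model: a string forward reference
    ticker: typing.Optional["Ticker"] = None
    exchange: typing.Optional[Exchange] = None

# The raw annotation still carries an unresolved ForwardRef...
raw_inner = typing.get_args(Company.__annotations__["ticker"])[0]
print(type(raw_inner).__name__)           # ForwardRef

# ...but typing.get_type_hints() evaluates it to the real class:
resolved = typing.get_type_hints(Company)["ticker"]
inner = typing.get_args(resolved)[0]
print(inner is Ticker)                    # True
```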