SparkModel.model_spark_schema() can't convert a nested BaseModel field #798

@rjurney

Description

I am using Boundary's baml-py package, which generates Pydantic classes at abzu.baml_client.types that look like the schema below:

import typing
from enum import Enum

from pydantic import BaseModel

class Exchange(str, Enum):
    NYSE = "NYSE"
    NASDAQ = "NASDAQ"
    AMEX = "AMEX"
    ARCA = "ARCA"
    TSXV = "TSXV"
    NYSEMKT = "NYSEMKT"
    LSE = "LSE"
    TSE = "TSE"
    HKEX = "HKEX"
    SGX = "SGX"
    BSE = "BSE"
    ASX = "ASX"
    SSE = "SSE"
    MSE = "MSE"
    CSE = "CSE"
    BMV = "BMV"
    OSE = "OSE"
    BME = "BME"
    SWX = "SWX"

class Ticker(BaseModel):
    id: int
    uuid: typing.Optional[str] = None
    symbol: str
    exchange: typing.Optional[Exchange] = None

class Company(BaseModel):
    id: int
    uuid: typing.Optional[str] = None
    name: str
    cik: typing.Optional[str] = None
    ticker: typing.Optional["Ticker"] = None
    description: str
    website_url: typing.Optional[str] = None
    headquarters_location: typing.Optional[str] = None
    jurisdiction: typing.Optional[str] = None
    revenue_usd: typing.Optional[int] = None
    employees: typing.Optional[int] = None
    founded_year: typing.Optional[int] = None
    ceo: typing.Optional[str] = None
    linkedin_url: typing.Optional[str] = None
    source_ids: typing.Optional[typing.List[int]] = None
    source_uuids: typing.Optional[typing.List[str]] = None

I get the following error when I try to generate a PySpark schema. From the docs, I gather that sparkdantic can't convert the nested class Ticker(BaseModel) field into a T.StructType(), as I'd like. Is this something I could fix, or alter the library to support?

from abzu.baml_client.types import Company, CompanyList

from sparkdantic import SparkModel

class SparkCompany(Company, SparkModel):
    pass

SparkCompany.model_spark_schema()

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:238, in create_json_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    237         metadata = _get_metadata(info)
--> 238         spark_type = _from_python_type(
    239             field_type, metadata, safe_casting, by_alias, mode, exclude_fields
    240         )
    241 except Exception as raised_error:

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:413, in _from_python_type(type_, metadata, safe_casting, by_alias, mode, exclude_fields)
    411     py_type = args[0]
--> 413 if issubclass(py_type, Enum):
    414     py_type = _get_enum_mixin_type(py_type)

TypeError: issubclass() arg 1 must be a class

The above exception was the direct cause of the following exception:

TypeConversionError                       Traceback (most recent call last)
Cell In[4], line 1
----> 1 SparkCompany.model_spark_schema()

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:121, in SparkModel.model_spark_schema(cls, safe_casting, by_alias, mode, exclude_fields)
    100 @classmethod
    101 def model_spark_schema(
    102     cls,
   (...)
    106     exclude_fields: bool = False,
    107 ) -> 'StructType':
    108     """Generates a PySpark schema from the model fields. This operates similarly to
    109     `pydantic.BaseModel.model_json_schema()`.
    110 
   (...)
    119         pyspark.sql.types.StructType: The generated PySpark schema.
    120     """
--> 121     return create_spark_schema(cls, safe_casting, by_alias, mode, exclude_fields)

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:306, in create_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    292 """Generates a PySpark schema from the model fields.
    293 
    294 Args:
   (...)
    303     pyspark.sql.types.StructType: The generated PySpark schema.
    304 """
    305 utils.require_pyspark()
--> 306 json_schema = create_json_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    307 return StructType.fromJson(json_schema)

File ~/anaconda3/envs/weave/lib/python3.12/site-packages/sparkdantic/model.py:242, in create_json_spark_schema(model, safe_casting, by_alias, mode, exclude_fields)
    238         spark_type = _from_python_type(
    239             field_type, metadata, safe_casting, by_alias, mode, exclude_fields
    240         )
    241 except Exception as raised_error:
--> 242     raise TypeConversionError(
    243         f'Error converting field `{name}` to PySpark type'
    244     ) from raised_error
    246 nullable = _is_optional(annotation_or_return_type)
    247 struct_field: Dict[str, Any] = {
    248     'name': name,
    249     'type': spark_type,
    250     'nullable': nullable,
    251     'metadata': comment,
    252 }

TypeConversionError: Error converting field `ticker` to PySpark type

It looks like the Exchange enum might work, though it is declared as Exchange(str, Enum) rather than the StrEnum the docs call for. I don't understand what it is about class Ticker(BaseModel) that makes it unconvertible, but I'm going to dig into the code and see whether you support pyspark.sql.types.StructType.
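To check my understanding, I put together a minimal stdlib-only repro (no sparkdantic or pydantic involved; the class names just mirror my schema). The str-mixin enum is a genuine subclass of str, but because Ticker is defined after Company, the `Optional["Ticker"]` annotation carries a `ForwardRef` rather than a class — which looks like exactly what `issubclass()` chokes on in the traceback above:

```python
import typing
from enum import Enum

class Exchange(str, Enum):
    NYSE = "NYSE"
    LSE = "LSE"

class Company:
    # Forward reference, as in the generated BAML code: Ticker is
    # defined *after* Company, so the annotation holds a string.
    ticker: typing.Optional["Ticker"] = None
    exchange: typing.Optional[Exchange] = None

class Ticker:
    symbol: str = ""

# The str-mixin Enum is a real subclass of str, so enum mixin-type
# detection should be fine with it:
assert issubclass(Exchange, Enum) and issubclass(Exchange, str)

# But the raw annotation keeps the forward reference unresolved:
inner = typing.get_args(Company.__annotations__["ticker"])[0]
print(type(inner))  # a typing.ForwardRef, not a class

try:
    issubclass(inner, Enum)
except TypeError as e:
    print(e)  # "issubclass() arg 1 must be a class" -- same error as above

# Resolving the type hints first yields the actual class:
resolved = typing.get_args(
    typing.get_type_hints(Company, globalns=globals())["ticker"]
)[0]
assert resolved is Ticker
```

So my guess is that the string forward reference in the generated code, not the nested BaseModel itself, is what trips up the conversion.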

Any help would be appreciated. I'm building an entity resolution pipeline that uses PySpark to group records into blocks for comparison and BAML to generate Pydantic classes, which are then used via the BAMLClient to do entity matching with an LLM. I can't modify the generated Pydantic classes directly, but I'll happily upgrade sparkdantic if you can point me in the right direction :)
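In case it helps narrow things down, one workaround I may try is resolving the forward references on the generated classes before asking for the schema. Here is a stdlib-only sketch of the idea — the helper name `resolve_forward_refs` is my own invention, and I haven't verified this is sufficient for sparkdantic:

```python
import typing

def resolve_forward_refs(cls, globalns=None):
    """Replace string/ForwardRef annotations on a class with the
    resolved types, so downstream issubclass() checks see real
    classes. (Helper name and approach are my own sketch.)"""
    cls.__annotations__ = typing.get_type_hints(cls, globalns=globalns)
    return cls

# Toy classes standing in for the generated BAML models:
class Company:
    ticker: typing.Optional["Ticker"] = None

class Ticker:
    symbol: str = ""

resolve_forward_refs(Company, globalns=globals())
inner = typing.get_args(Company.__annotations__["ticker"])[0]
assert inner is Ticker  # no ForwardRef left behind
```

With Pydantic v2 models I believe `Company.model_rebuild()` serves a similar purpose, but I haven't confirmed whether it updates the annotations that sparkdantic inspects.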
