33
44from __future__ import annotations
55
6- from collections .abc import Sequence
6+ import enum
7+ from collections .abc import Iterable
8+ from inspect import isclass
79from typing import Any
810
911import polars as pl
@@ -22,7 +24,7 @@ class Enum(Column):
2224
2325 def __init__ (
2426 self ,
25- categories : Sequence [str ],
27+ categories : pl . Series | Iterable [str ] | type [ enum . Enum ],
2628 * ,
2729 nullable : bool | None = None ,
2830 primary_key : bool = False ,
@@ -32,7 +34,8 @@ def __init__(
3234 ):
3335 """
3436 Args:
35- categories: The list of valid categories for the enum.
37+ categories: The set of valid categories for the enum, or an existing Python
38+ string-valued enum.
3639 nullable: Whether this column may contain null values.
3740 Explicitly set `nullable=True` if you want your column to be nullable.
3841 In a future release, `nullable=False` will be the default if `nullable`
@@ -63,7 +66,13 @@ def __init__(
6366 alias = alias ,
6467 metadata = metadata ,
6568 )
66- self .categories = list (categories )
69+ if isclass (categories ) and issubclass (categories , enum .Enum ):
70+ categories = pl .Series (
71+ values = [getattr (v , "value" , v ) for v in categories .__members__ .values ()]
72+ )
73+ elif not isinstance (categories , pl .Series ):
74+ categories = pl .Series (values = categories )
75+ self .categories = categories
6776
6877 @property
6978 def dtype (self ) -> pl .DataType :
@@ -72,7 +81,7 @@ def dtype(self) -> pl.DataType:
7281 def validate_dtype (self , dtype : PolarsDataType ) -> bool :
7382 if not isinstance (dtype , pl .Enum ):
7483 return False
75- return self .categories == dtype .categories . to_list ( )
84+ return self .categories . equals ( dtype .categories )
7685
7786 def sqlalchemy_dtype (self , dialect : sa .Dialect ) -> sa_TypeEngine :
7887 category_lengths = [len (c ) for c in self .categories ]
@@ -92,5 +101,7 @@ def pyarrow_dtype(self) -> pa.DataType:
92101
93102 def _sample_unchecked (self , generator : Generator , n : int ) -> pl .Series :
94103 return generator .sample_choice (
95- n , choices = self .categories , null_probability = self ._null_probability
104+ n ,
105+ choices = self .categories .to_list (),
106+ null_probability = self ._null_probability ,
96107 ).cast (self .dtype )
0 commit comments