@@ -33,11 +33,14 @@ class DaskJsonDict(JsonDict):
3333 name : str
3434 chunks : Iterable [tuple [int , ...]]
3535 dtype : str
36+ shape : Union [tuple [int , ...], None ] = None
3637 value : list
3738
3839 def to_array_input (self ) -> DaskArray :
3940 """Construct a dask array"""
4041 np_array = np .array (self .value , dtype = self .dtype )
42+ if self .shape is not None and np_array .shape != self .shape :
43+ np_array = self .reshape_input (np_array , self .shape )
4144 array = from_array (
4245 np_array ,
4346 name = self .name ,
@@ -75,28 +78,37 @@ def before_validation(self, array: DaskArray) -> NDArrayType:
7578 Try and coerce dicts that should be model objects into the model objects
7679 """
7780 try :
78- if issubclass (self .dtype , BaseModel ) and isinstance (
79- array .reshape (- 1 )[0 ].compute (), dict
80- ):
81+ if issubclass (self .dtype , BaseModel ):
82+ flat_array = array .reshape (- 1 )
83+ if len (flat_array ) == 0 :
84+ return array
8185
82- def _chunked_to_model (array : np .ndarray ) -> np .ndarray :
83- def _vectorized_to_model (item : Union [dict , BaseModel ]) -> BaseModel :
84- if not isinstance (item , self .dtype ):
85- return self .dtype (** item )
86- else : # pragma: no cover
87- return item
86+ if isinstance (flat_array [0 ].compute (), dict ):
8887
89- return np .vectorize (_vectorized_to_model )(array )
88+ def _chunked_to_model (array : np .ndarray ) -> np .ndarray :
89+ def _vectorized_to_model (
90+ item : Union [dict , BaseModel ],
91+ ) -> BaseModel :
92+ if not isinstance (item , self .dtype ):
93+ return self .dtype (** item )
94+ else : # pragma: no cover
95+ return item
9096
91- array = array .map_blocks (_chunked_to_model , dtype = self .dtype )
97+ return np .vectorize (_vectorized_to_model )(array )
98+
99+ array = array .map_blocks (_chunked_to_model , dtype = self .dtype )
92100 except TypeError :
93101 # fine, dtype isn't a type
94102 pass
95103 return array
96104
97105 def get_object_dtype (self , array : NDArrayType ) -> DtypeType :
98106 """Dask arrays require a compute() call to retrieve a single value"""
99- return type (array .reshape (- 1 )[0 ].compute ())
107+ flat_array = array .reshape (- 1 )
108+ if len (flat_array ) == 0 :
109+ return Any
110+ else :
111+ return type (flat_array [0 ].compute ())
100112
101113 @classmethod
102114 def enabled (cls ) -> bool :
@@ -121,12 +133,15 @@ def to_json(
121133 """
122134 np_array = np .array (array )
123135 as_json = np_array .tolist ()
136+ if not isinstance (as_json , list ):
137+ as_json = [as_json ]
124138 if info .round_trip :
125139 as_json = DaskJsonDict (
126140 type = cls .name ,
127141 value = as_json ,
128142 name = array .name ,
129143 chunks = array .chunks ,
130144 dtype = str (np_array .dtype ),
145+ shape = array .shape ,
131146 )
132147 return as_json
0 commit comments