@@ -320,6 +320,39 @@ def backmap_probs(self, probs, n_values, backmappers):
320320 new_probs = new_probs / tots [:, None ]
321321 return new_probs
322322
323+ def data_to_model_domain (self , data : Table ) -> Table :
324+ """
325+ Transforms data to the model domain if possible.
326+
327+ Parameters
328+ ----------
329+ data
330+ Data to be transformed to the model domain
331+
332+ Returns
333+ -------
334+ Transformed data table
335+
336+ Raises
337+ ------
338+ DomainTransformationError
339+ Error indicates that transformation is not possible since domains
340+ are not compatible
341+ """
342+ if data .domain == self .domain :
343+ return data
344+
345+ if self .original_domain .attributes != data .domain .attributes \
346+ and data .X .size \
347+ and not all_nan (data .X ):
348+ new_data = data .transform (self .original_domain )
349+ if all_nan (new_data .X ):
350+ raise DomainTransformationError (
351+ "domain transformation produced no defined values" )
352+ return new_data .transform (self .domain )
353+
354+ return data .transform (self .domain )
355+
323356 def __call__ (self , data , ret = Value ):
324357 multitarget = len (self .domain .class_vars ) > 1
325358
@@ -336,21 +369,6 @@ def one_hot_probs(value):
336369 def fix_dim (x ):
337370 return x [0 ] if one_d else x
338371
339- def data_to_model_domain ():
340- if data .domain == self .domain :
341- return data
342-
343- if self .original_domain .attributes != data .domain .attributes \
344- and data .X .size \
345- and not all_nan (data .X ):
346- new_data = data .transform (self .original_domain )
347- if all_nan (new_data .X ):
348- raise DomainTransformationError (
349- "domain transformation produced no defined values" )
350- return new_data .transform (self .domain )
351-
352- return data .transform (self .domain )
353-
354372 if not 0 <= ret <= 2 :
355373 raise ValueError ("invalid value of argument 'ret'" )
356374 if ret > 0 and any (v .is_continuous for v in self .domain .class_vars ):
@@ -368,14 +386,18 @@ def data_to_model_domain():
368386 else :
369387 one_d = False
370388
389+ # if sparse convert to csr_matrix
390+ if scipy .sparse .issparse (data ):
391+ data = data .tocsr ()
392+
371393 # Call the predictor
372394 backmappers = None
373395 n_values = []
374396 if isinstance (data , (np .ndarray , scipy .sparse .csr .csr_matrix )):
375397 prediction = self .predict (data )
376398 elif isinstance (data , Table ):
377399 backmappers , n_values = self .get_backmappers (data )
378- data = data_to_model_domain ()
400+ data = self . data_to_model_domain (data )
379401 prediction = self .predict_storage (data )
380402 elif isinstance (data , (list , tuple )):
381403 data = Table .from_list (self .original_domain , data )
0 commit comments