@@ -61,7 +61,7 @@ def update(output: SparseObject,
6161 if accum is None or output ._obj is None :
6262 output .set_element (tensor .extract_element ())
6363 else :
64- output ._obj = impl .ewise_add (accum , output , tensor )._obj
64+ output ._obj = impl .ewise_add (output . dtype , accum , output , tensor )._obj
6565 return
6666
6767 if not isinstance (output , SparseTensor ):
@@ -80,7 +80,7 @@ def update(output: SparseObject,
8080 else : # mask=Y, accum=N, replace=N
8181 # Apply inverted mask, then eWiseAdd
8282 output ._replace (impl .select_by_mask (output , mask , desc_inverted ))
83- result = impl .ewise_add (BinaryOp .oneb , output , tensor )
83+ result = impl .ewise_add (output . dtype , BinaryOp .oneb , output , tensor )
8484 else :
8585 if mask is None : # mask=N, accum=N, replace=?, w/ indices
8686 # Drop indices in output, then eWiseAdd
@@ -106,22 +106,22 @@ def update(output: SparseObject,
106106 # Select the row/col indices in the mask, apply it inverted to the output, then eWiseAdd
107107 new_mask = impl .select_by_indices (mask , row_indices , col_indices )
108108 output ._replace (impl .select_by_mask (output , new_mask , desc_inverted ))
109- result = impl .ewise_add (BinaryOp .oneb , output , tensor )
109+ result = impl .ewise_add (output . dtype , BinaryOp .oneb , output , tensor )
110110 elif mask is None or not desc .replace :
111111 # eWiseAdd using accum
112- result = impl .ewise_add (accum , output , tensor )
112+ result = impl .ewise_add (output . dtype , accum , output , tensor )
113113 else :
114114 # Mask the output, then perform eWiseAdd using accum
115115 output ._replace (impl .select_by_mask (output , mask , desc ))
116- result = impl .ewise_add (accum , output , tensor )
116+ result = impl .ewise_add (output . dtype , accum , output , tensor )
117117
118118 if result is output :
119119 # This can happen if empty tensors are used as input
120120 return output
121121
122- # If not an intermediate result, make a copy
123- if not result ._intermediate_result :
124- result = impl .dup (result )
122+ # If not an intermediate result or wrong dtype , make a copy
123+ if not result ._intermediate_result or result . dtype != output . dtype :
124+ result = impl .dup (output . dtype , result )
125125
126126 output ._replace (result )
127127
@@ -132,10 +132,6 @@ def transpose(out: Matrix,
132132 mask : Optional [SparseTensor ] = None ,
133133 accum : Optional [BinaryOp ] = None ,
134134 desc : Descriptor = NULL_DESC ):
135- # Verify dtypes
136- if out .dtype != tensor .dtype :
137- raise GrbDomainMismatch (f"output type must be { tensor .dtype } , not { out .dtype } " )
138-
139135 # Apply descriptor transpose
140136 if tensor .ndims != 2 :
141137 raise TypeError (f"transpose requires Matrix, not { type (tensor )} " )
@@ -171,14 +167,6 @@ def ewise_add(out: SparseTensor,
171167 if type (op ) is not BinaryOp :
172168 raise TypeError (f"op must be BinaryOp, Monoid, or Semiring" )
173169
174- # Verify dtypes
175- if op .output is not None and type (op .output ) is not int :
176- raise GrbDomainMismatch ("op must return same type as inputs with ewise_add" )
177- if left .dtype != right .dtype :
178- raise GrbDomainMismatch (f"inputs must have same dtype: { left .dtype } != { right .dtype } " )
179- if out .dtype != left .dtype :
180- raise GrbDomainMismatch (f"output type must be { left .dtype } , not { out .dtype } " )
181-
182170 # Apply transposes
183171 if desc .transpose0 and left .ndims == 2 :
184172 left = TransposedMatrix .wrap (left )
@@ -195,7 +183,7 @@ def ewise_add(out: SparseTensor,
195183 left = impl .select_by_mask (left , mask , desc )
196184 right = impl .select_by_mask (right , mask , desc )
197185
198- result = impl .ewise_add (op , left , right )
186+ result = impl .ewise_add (out . dtype , op , left , right )
199187 update (out , result , mask , accum , desc )
200188
201189
@@ -214,13 +202,6 @@ def ewise_mult(out: SparseTensor,
214202 else :
215203 raise TypeError (f"op must be BinaryOp, Monoid, or Semiring" )
216204
217- # Verify dtypes
218- if left .dtype != right .dtype :
219- raise GrbDomainMismatch (f"inputs must have same dtype: { left .dtype } != { right .dtype } " )
220- required_out_dtype = op .get_output_type (left .dtype , right .dtype )
221- if out .dtype != required_out_dtype :
222- raise GrbDomainMismatch (f"output type must be { required_out_dtype } , not { out .dtype } " )
223-
224205 # Apply transposes
225206 if desc .transpose0 and left .ndims == 2 :
226207 left = TransposedMatrix .wrap (left )
@@ -237,7 +218,7 @@ def ewise_mult(out: SparseTensor,
237218 # Only need to apply mask to one of the inputs
238219 left = impl .select_by_mask (left , mask , desc )
239220
240- result = impl .ewise_mult (op , left , right )
221+ result = impl .ewise_mult (out . dtype , op , left , right )
241222 update (out , result , mask , accum , desc )
242223
243224
@@ -253,13 +234,6 @@ def mxm(out: Matrix,
253234 if type (op ) is not Semiring :
254235 raise TypeError (f"op must be Semiring, not { type (op )} " )
255236
256- # Verify dtypes
257- if left .dtype != right .dtype :
258- raise GrbDomainMismatch (f"inputs must have same dtype: { left .dtype } != { right .dtype } " )
259- required_out_dtype = op .binop .get_output_type (left .dtype , right .dtype )
260- if out .dtype != required_out_dtype :
261- raise GrbDomainMismatch (f"output type must be { required_out_dtype } , not { out .dtype } " )
262-
263237 # Apply transposes
264238 if left .ndims != right .ndims != 2 :
265239 raise GrbDimensionMismatch ("mxm requires rank 2 tensors" )
@@ -285,7 +259,7 @@ def mxm(out: Matrix,
285259 right = impl .flip_layout (right )
286260
287261 # TODO: apply the mask during the computation, not at the end
288- result = impl .mxm (op , left , right )
262+ result = impl .mxm (out . dtype , op , left , right )
289263 if mask is not None :
290264 result = impl .select_by_mask (result , mask , desc )
291265 update (out , result , mask , accum , desc )
@@ -303,13 +277,6 @@ def mxv(out: Vector,
303277 if type (op ) is not Semiring :
304278 raise TypeError (f"op must be Semiring, not { type (op )} " )
305279
306- # Verify dtypes
307- if left .dtype != right .dtype :
308- raise GrbDomainMismatch (f"inputs must have same dtype: { left .dtype } != { right .dtype } " )
309- required_out_dtype = op .binop .get_output_type (left .dtype , right .dtype )
310- if out .dtype != required_out_dtype :
311- raise GrbDomainMismatch (f"output type must be { required_out_dtype } , not { out .dtype } " )
312-
313280 # Apply transpose
314281 if left .ndims != 2 :
315282 raise GrbDimensionMismatch ("mxv requires matrix as first input" )
@@ -325,7 +292,7 @@ def mxv(out: Vector,
325292 raise GrbDimensionMismatch (f"output size should be { left .shape [0 ]} not { out .shape [0 ]} " )
326293
327294 # TODO: apply the mask during the computation, not at the end
328- result = impl .mxv (op , left , right )
295+ result = impl .mxv (out . dtype , op , left , right )
329296 if mask is not None :
330297 result = impl .select_by_mask (result , mask , desc )
331298 update (out , result , mask , accum , desc )
@@ -343,13 +310,6 @@ def vxm(out: Vector,
343310 if type (op ) is not Semiring :
344311 raise TypeError (f"op must be Semiring, not { type (op )} " )
345312
346- # Verify dtypes
347- if left .dtype != right .dtype :
348- raise GrbDomainMismatch (f"inputs must have same dtype: { left .dtype } != { right .dtype } " )
349- required_out_dtype = op .binop .get_output_type (left .dtype , right .dtype )
350- if out .dtype != required_out_dtype :
351- raise GrbDomainMismatch (f"output type must be { required_out_dtype } , not { out .dtype } " )
352-
353313 # Apply transpose
354314 if right .ndims != 2 :
355315 raise GrbDimensionMismatch ("vxm requires matrix as second input" )
@@ -365,7 +325,7 @@ def vxm(out: Vector,
365325 raise GrbDimensionMismatch (f"output size should be { right .shape [1 ]} not { out .shape [0 ]} " )
366326
367327 # TODO: apply the mask during the computation, not at the end
368- result = impl .vxm (op , left , right )
328+ result = impl .vxm (out . dtype , op , left , right )
369329 if mask is not None :
370330 result = impl .select_by_mask (result , mask , desc )
371331 update (out , result , mask , accum , desc )
@@ -386,7 +346,6 @@ def apply(out: SparseTensor,
386346 if optype is UnaryOp :
387347 if thunk is not None or left is not None or right is not None :
388348 raise TypeError ("UnaryOp does not accept thunk, left, or right" )
389- required_out_dtype = op .get_output_type (tensor .dtype )
390349 elif optype is BinaryOp :
391350 if thunk is not None :
392351 raise TypeError ("BinaryOp accepts left or thing, not thunk" )
@@ -396,23 +355,16 @@ def apply(out: SparseTensor,
396355 raise TypeError ("Cannot provide both left and right" )
397356 if left is not None :
398357 left = ensure_scalar_of_type (left , tensor .dtype )
399- required_out_dtype = op .get_output_type (left .dtype , tensor .dtype )
400358 else :
401359 right = ensure_scalar_of_type (right , tensor .dtype )
402- required_out_dtype = op .get_output_type (tensor .dtype , right .dtype )
403360 elif optype is IndexUnaryOp :
404361 if left is not None or right is not None :
405362 raise TypeError ("IndexUnaryOp accepts thunk, not left or right" )
406363 thunk_dtype = INT64 if op .thunk_as_index else tensor .dtype
407364 thunk = ensure_scalar_of_type (thunk , thunk_dtype )
408- required_out_dtype = op .get_output_type (tensor .dtype , thunk .dtype )
409365 else :
410366 raise TypeError (f"op must be UnaryOp, BinaryOp, or IndexUnaryOp, not { type (op )} " )
411367
412- # Verify dtype
413- if out .dtype != required_out_dtype :
414- raise GrbDomainMismatch (f"output type must be { required_out_dtype } , not { out .dtype } " )
415-
416368 # Apply transpose
417369 if desc .transpose0 and tensor .ndims == 2 :
418370 tensor = TransposedMatrix .wrap (tensor )
@@ -433,9 +385,9 @@ def apply(out: SparseTensor,
433385 and desc is NULL_DESC
434386 and not tensor ._intermediate_result
435387 ):
436- impl .apply (op , tensor , left , right , None , inplace = True )
388+ impl .apply (out . dtype , op , tensor , left , right , None , inplace = True )
437389 else :
438- result = impl .apply (op , tensor , left , right , thunk )
390+ result = impl .apply (out . dtype , op , tensor , left , right , thunk )
439391 update (out , result , mask , accum , desc )
440392
441393
@@ -452,8 +404,6 @@ def select(out: SparseTensor,
452404 raise TypeError (f"op must be SelectOp, not { type (op )} " )
453405
454406 # Verify dtypes
455- if out .dtype != tensor .dtype :
456- raise GrbDomainMismatch (f"output dtype must match input dtype: { out .dtype } != { tensor .dtype } " )
457407 thunk_dtype = INT64 if op .thunk_as_index else tensor .dtype
458408 thunk = ensure_scalar_of_type (thunk , thunk_dtype )
459409
@@ -468,7 +418,7 @@ def select(out: SparseTensor,
468418 if mask is not None :
469419 tensor = impl .select_by_mask (tensor , mask , desc )
470420
471- result = impl .select (op , tensor , thunk )
421+ result = impl .select (out . dtype , op , tensor , thunk )
472422 update (out , result , mask , accum , desc )
473423
474424
@@ -483,10 +433,6 @@ def reduce_to_vector(out: Vector,
483433 if type (op ) is not Monoid :
484434 raise TypeError (f"op must be Monoid, not { type (op )} " )
485435
486- # Verify dtypes
487- if out .dtype != tensor .dtype :
488- raise GrbDomainMismatch (f"output dtype must match input dtype: { out .dtype } != { tensor .dtype } " )
489-
490436 # Apply transpose
491437 if tensor .ndims != 2 :
492438 raise GrbDimensionMismatch ("reduce_to_vector requires matrix input" )
@@ -500,7 +446,7 @@ def reduce_to_vector(out: Vector,
500446 raise GrbDimensionMismatch (f"output size should be { tensor .shape [0 ]} not { out .shape [0 ]} " )
501447
502448 # TODO: apply the mask during the computation, not at the end
503- result = impl .reduce_to_vector (op , tensor )
449+ result = impl .reduce_to_vector (out . dtype , op , tensor )
504450 if mask is not None :
505451 result = impl .select_by_mask (result , mask , desc )
506452 update (out , result , mask , accum , desc )
@@ -516,15 +462,11 @@ def reduce_to_scalar(out: Scalar,
516462 if type (op ) is not Monoid :
517463 raise TypeError (f"op must be Monoid, not { type (op )} " )
518464
519- # Verify dtypes
520- if out .dtype != tensor .dtype :
521- raise GrbDomainMismatch (f"output dtype must match input dtype: { out .dtype } != { tensor .dtype } " )
522-
523465 # Compare shapes
524466 if out .ndims != 0 :
525467 raise GrbDimensionMismatch ("reduce_to_scalar requires scalar output" )
526468
527- result = impl .reduce_to_scalar (op , tensor )
469+ result = impl .reduce_to_scalar (out . dtype , op , tensor )
528470 update (out , result , accum = accum , desc = desc )
529471
530472
@@ -539,10 +481,6 @@ def extract(out: SparseTensor,
539481 """
540482 Setting row_indices or col_indices to `None` is the equivalent of GrB_ALL
541483 """
542- # Verify dtypes
543- if out .dtype != tensor .dtype :
544- raise GrbDomainMismatch (f"output must have same dtype as input: { out .dtype } != { tensor .dtype } " )
545-
546484 # Apply transpose
547485 if desc .transpose0 and tensor .ndims == 2 :
548486 tensor = TransposedMatrix .wrap (tensor )
@@ -587,7 +525,7 @@ def extract(out: SparseTensor,
587525 if out .shape != expected_out_shape :
588526 raise GrbDimensionMismatch (f"output shape mismatch: { out .shape } != { expected_out_shape } " )
589527
590- result = impl .extract (tensor , row_indices , col_indices , row_size , col_size )
528+ result = impl .extract (out . dtype , tensor , row_indices , col_indices , row_size , col_size )
591529 if mask is not None :
592530 result = impl .select_by_mask (result , mask , desc )
593531 update (out , result , mask , accum , desc )
@@ -610,10 +548,6 @@ def assign(out: SparseTensor,
610548 raise TypeError (f"tensor must be a SparseObject or Python scalar, not { type (tensor )} " )
611549 tensor = ensure_scalar_of_type (tensor , out .dtype )
612550
613- # Verify dtypes
614- if out .dtype != tensor .dtype :
615- raise GrbDomainMismatch (f"output must have same dtype as input: { out .dtype } != { tensor .dtype } " )
616-
617551 # Apply transpose
618552 if desc .transpose0 and tensor .ndims == 2 :
619553 tensor = TransposedMatrix .wrap (tensor )
@@ -653,7 +587,7 @@ def assign(out: SparseTensor,
653587 if mask is None :
654588 raise GrbError ("This will create a dense matrix. Please provide a mask or indices." )
655589 # Use mask to build an iso-valued Matrix
656- result = impl .apply (BinaryOp .second , mask , right = tensor )
590+ result = impl .apply (out . dtype , BinaryOp .second , mask , right = tensor )
657591 else :
658592 if out .ndims == 1 : # Vector output
659593 result = impl .build_iso_vector_from_indices (out .dtype , * out .shape , row_indices , tensor )
@@ -675,7 +609,7 @@ def assign(out: SparseTensor,
675609 if tensor .shape != expected_input_shape :
676610 raise GrbDimensionMismatch (f"input shape mismatch: { tensor .shape } != { expected_input_shape } " )
677611
678- result = impl .assign (tensor , row_indices , col_indices , * out .shape )
612+ result = impl .assign (out . dtype , tensor , row_indices , col_indices , * out .shape )
679613 if mask is not None :
680614 result = impl .select_by_mask (result , mask , desc )
681615 update (out , result , mask , accum , desc , row_indices = row_indices , col_indices = col_indices )
0 commit comments