1616from .operators import UnaryOp , BinaryOp , SelectOp , IndexUnaryOp , Monoid , Semiring
1717from .compiler import compile , engine_cache
1818from . descriptor import Descriptor , NULL as NULL_DESC
19- from .utils import get_sparse_output_pointer , get_scalar_output_pointer , pick_and_renumber_indices
19+ from .utils import (get_sparse_output_pointer , get_scalar_output_pointer ,
20+ get_scalar_input_arg , pick_and_renumber_indices )
2021from .types import RankedTensorType , BOOL , INT64 , FP64
21- from .exceptions import GrbIndexOutOfBounds , GrbDimensionMismatch
22+ from .exceptions import GrbError , GrbIndexOutOfBounds , GrbDimensionMismatch
2223
2324
2425# TODO: vec->matrix broadcasting as builtin param in select_by_mask (rowwise/colwise)
@@ -49,8 +50,7 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
4950
5051 # Convert value mask to structural mask
5152 if not desc .mask_structure :
52- zero = Scalar .new (mask .dtype )
53- zero .set_element (0 )
53+ zero = Scalar .new (mask .dtype , 0 )
5454 mask = select (SelectOp .valuene , mask , thunk = zero )
5555
5656 # Build and compile if needed
@@ -292,6 +292,22 @@ def main(x):
292292 return compile (module )
293293
294294
295+ def _build_scalar_binop (op : BinaryOp , left : Scalar , right : Scalar ):
296+ # Both scalars are present
297+ with ir .Context (), ir .Location .unknown ():
298+ module = ir .Module .create ()
299+ with ir .InsertionPoint (module .body ):
300+ dtype = left .dtype .build_mlir_type ()
301+
302+ @func .FuncOp .from_py_func (dtype , dtype )
303+ def main (x , y ):
304+ result = op (x , y )
305+ return result
306+ main .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
307+
308+ return compile (module )
309+
310+
295311def ewise_add (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
296312 assert left .ndims == right .ndims
297313 assert left .dtype == right .dtype
@@ -301,12 +317,17 @@ def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
301317 if right ._obj is None :
302318 return left
303319
304- assert left ._sparsity == right ._sparsity
305-
306320 rank = left .ndims
307321 if rank == 0 : # Scalar
308- # TODO: implement this
309- raise NotImplementedError ("doesn't yet work for Scalar" )
322+ key = ('scalar_binop' , op .name , left .dtype , right .dtype )
323+ if key not in engine_cache :
324+ engine_cache [key ] = _build_scalar_binop (op , left , right )
325+ mem_out = get_scalar_output_pointer (left .dtype )
326+ arg_pointers = [get_scalar_input_arg (left ), get_scalar_input_arg (right ), mem_out ]
327+ engine_cache [key ].invoke ('main' , * arg_pointers )
328+ return Scalar (left .dtype , (), left .dtype .np_type (mem_out .contents .value ))
329+
330+ assert left ._sparsity == right ._sparsity
310331
311332 # Build and compile if needed
312333 key = ('ewise_add' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -366,18 +387,22 @@ def main(x, y):
366387def ewise_mult (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
367388 assert left .ndims == right .ndims
368389 assert left .dtype == right .dtype
390+ output_dtype = op .get_output_type (left .dtype , right .dtype )
369391
370- if left ._obj is None :
371- return left
372- if right ._obj is None :
373- return right
374-
375- assert left ._sparsity == right ._sparsity
392+ if left ._obj is None or right ._obj is None :
393+ return left .baseclass (output_dtype , left .shape )
376394
377395 rank = left .ndims
378396 if rank == 0 : # Scalar
379- # TODO: implement this
380- raise NotImplementedError ("doesn't yet work for Scalar" )
397+ key = ('scalar_binop' , op .name , left .dtype , right .dtype )
398+ if key not in engine_cache :
399+ engine_cache [key ] = _build_scalar_binop (op , left , right )
400+ mem_out = get_scalar_output_pointer (output_dtype )
401+ arg_pointers = [get_scalar_input_arg (left ), get_scalar_input_arg (right ), mem_out ]
402+ engine_cache [key ].invoke ('main' , * arg_pointers )
403+ return Scalar (output_dtype , (), output_dtype .np_type (mem_out .contents .value ))
404+
405+ assert left ._sparsity == right ._sparsity
381406
382407 # Build and compile if needed
383408 key = ('ewise_mult' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -388,7 +413,7 @@ def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
388413 mem_out = get_sparse_output_pointer ()
389414 arg_pointers = [left ._obj , right ._obj , mem_out ]
390415 engine_cache [key ].invoke ('main' , * arg_pointers )
391- return left .baseclass (op . get_output_type ( left . dtype , right . dtype ) , left .shape , mem_out ,
416+ return left .baseclass (output_dtype , left .shape , mem_out ,
392417 left ._sparsity , left .perceived_ordering , intermediate_result = True )
393418
394419
@@ -671,11 +696,6 @@ def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
671696 right : Optional [Scalar ] = None ,
672697 thunk : Optional [Scalar ] = None ,
673698 inplace : bool = False ):
674- rank = sp .ndims
675- if rank == 0 : # Scalar
676- # TODO: implement this
677- raise NotImplementedError ("doesn't yet work for Scalar" )
678-
679699 # Find output dtype
680700 optype = type (op )
681701 if optype is UnaryOp :
@@ -693,6 +713,25 @@ def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
693713 if sp ._obj is None :
694714 return sp .baseclass (output_dtype , sp .shape )
695715
716+ rank = sp .ndims
717+ if rank == 0 : # Scalar
718+ if optype is UnaryOp :
719+ key = ('scalar_apply_unary' , op .name , sp .dtype )
720+ elif optype is BinaryOp :
721+ if left is not None :
722+ key = ('scalar_apply_bind_first' , op .name , sp .dtype , left ._obj )
723+ else :
724+ key = ('scalar_apply_bind_second' , op .name , sp .dtype , right ._obj )
725+ else :
726+ raise GrbError ("apply scalar not supported for IndexUnaryOp" )
727+
728+ if key not in engine_cache :
729+ engine_cache [key ] = _build_scalar_apply (op , sp , left , right )
730+ mem_out = get_scalar_output_pointer (output_dtype )
731+ arg_pointers = [get_scalar_input_arg (sp ), mem_out ]
732+ engine_cache [key ].invoke ('main' , * arg_pointers )
733+ return Scalar .new (output_dtype , mem_out .contents .value )
734+
696735 # Build and compile if needed
697736 # Note that Scalars are included in the key because they are inlined in the compiled code
698737 if optype is UnaryOp :
@@ -721,6 +760,33 @@ def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
721760 sp ._sparsity , sp .perceived_ordering , intermediate_result = True )
722761
723762
763+ def _build_scalar_apply (op : Union [UnaryOp , BinaryOp ],
764+ sp : SparseTensorBase ,
765+ left : Optional [Scalar ],
766+ right : Optional [Scalar ]):
767+ optype = type (op )
768+ with ir .Context (), ir .Location .unknown ():
769+ module = ir .Module .create ()
770+ with ir .InsertionPoint (module .body ):
771+ dtype = sp .dtype .build_mlir_type ()
772+
773+ @func .FuncOp .from_py_func (dtype )
774+ def main (x ):
775+ if optype is BinaryOp :
776+ if left is not None :
777+ left_val = arith .ConstantOp (left .dtype .build_mlir_type (), left .extract_element ())
778+ result = op (left_val , x )
779+ else :
780+ right_val = arith .ConstantOp (right .dtype .build_mlir_type (), right .extract_element ())
781+ result = op (x , right_val )
782+ else :
783+ result = op (x )
784+ return result
785+ main .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
786+
787+ return compile (module )
788+
789+
724790def _build_apply (op : Union [UnaryOp , BinaryOp , IndexUnaryOp ],
725791 sp : SparseTensorBase ,
726792 left : Optional [Scalar ],
@@ -768,16 +834,16 @@ def main(x):
768834 arg0 , = present .arguments
769835 if optype is IndexUnaryOp :
770836 if op .thunk_as_index :
771- thunk_val = arith .ConstantOp (index , thunk ._obj . item ())
837+ thunk_val = arith .ConstantOp (index , thunk .extract_element ())
772838 else :
773- thunk_val = arith .ConstantOp (thunk .dtype .build_mlir_type (), thunk ._obj . item ())
839+ thunk_val = arith .ConstantOp (thunk .dtype .build_mlir_type (), thunk .extract_element ())
774840 val = op (arg0 , rowidx , colidx , thunk_val )
775841 elif optype is BinaryOp :
776842 if left is not None :
777- left_val = arith .ConstantOp (left .dtype .build_mlir_type (), left ._obj . item ())
843+ left_val = arith .ConstantOp (left .dtype .build_mlir_type (), left .extract_element ())
778844 val = op (left_val , arg0 )
779845 else :
780- right_val = arith .ConstantOp (right .dtype .build_mlir_type (), right ._obj . item ())
846+ right_val = arith .ConstantOp (right .dtype .build_mlir_type (), right .extract_element ())
781847 val = op (arg0 , right_val )
782848 else :
783849 val = op (arg0 )
@@ -818,10 +884,10 @@ def main(x):
818884 val = memref .LoadOp (vals , [x ])
819885 if optype is BinaryOp :
820886 if left is not None :
821- left_val = arith .ConstantOp (left .dtype .build_mlir_type (), left ._obj . item ())
887+ left_val = arith .ConstantOp (left .dtype .build_mlir_type (), left .extract_element ())
822888 result = op (left_val , val )
823889 else :
824- right_val = arith .ConstantOp (right .dtype .build_mlir_type (), right ._obj . item ())
890+ right_val = arith .ConstantOp (right .dtype .build_mlir_type (), right .extract_element ())
825891 result = op (val , right_val )
826892 else :
827893 result = op (val )
@@ -833,15 +899,24 @@ def main(x):
833899
834900
835901def select (op : SelectOp , sp : SparseTensor , thunk : Scalar ):
836- rank = sp .ndims
837- if rank == 0 : # Scalar
838- # TODO: implement this
839- raise NotImplementedError ("doesn't yet work for Scalar" )
840-
841902 # Handle case of empty tensor
842903 if sp ._obj is None :
843904 return sp .__class__ (sp .dtype , sp .shape )
844905
906+ rank = sp .ndims
907+ if rank == 0 : # Scalar
908+ key = ('scalar_select' , op .name , sp .dtype , thunk ._obj )
909+ if key not in engine_cache :
910+ engine_cache [key ] = _build_scalar_select (op , sp , thunk )
911+ mem_out = get_scalar_output_pointer (sp .dtype )
912+ arg_pointers = [get_scalar_input_arg (sp ), mem_out ]
913+ engine_cache [key ].invoke ('main' , * arg_pointers )
914+ # Invocation returns True/False for whether to keep value
915+ if mem_out .contents .value :
916+ return sp .dup ()
917+ else :
918+ return Scalar .new (sp .dtype )
919+
845920 # Build and compile if needed
846921 # Note that thunk is included in the key because it is inlined in the compiled code
847922 key = ('select' , op .name , * sp .get_loop_key (), thunk ._obj )
@@ -856,6 +931,27 @@ def select(op: SelectOp, sp: SparseTensor, thunk: Scalar):
856931 sp ._sparsity , sp .perceived_ordering , intermediate_result = True )
857932
858933
934+ def _build_scalar_select (op : SelectOp , sp : SparseTensorBase , thunk : Scalar ):
935+ with ir .Context (), ir .Location .unknown ():
936+ module = ir .Module .create ()
937+ with ir .InsertionPoint (module .body ):
938+ index = ir .IndexType .get ()
939+ dtype = sp .dtype .build_mlir_type ()
940+
941+ @func .FuncOp .from_py_func (dtype )
942+ def main (x ):
943+ c0 = arith .ConstantOp (index , 0 )
944+ if op .thunk_as_index :
945+ thunk_val = arith .ConstantOp (index , thunk .extract_element ())
946+ else :
947+ thunk_val = arith .ConstantOp (thunk .dtype .build_mlir_type (), thunk .extract_element ())
948+ cmp = op (x , c0 , c0 , thunk_val )
949+ return cmp
950+ main .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
951+
952+ return compile (module )
953+
954+
859955def _build_select (op : SelectOp , sp : SparseTensorBase , thunk : Scalar ):
860956 with ir .Context (), ir .Location .unknown ():
861957 module = ir .Module .create ()
@@ -894,9 +990,9 @@ def main(x):
894990 with ir .InsertionPoint (region ):
895991 arg0 , = region .arguments
896992 if op .thunk_as_index :
897- thunk_val = arith .ConstantOp (index , thunk ._obj . item ())
993+ thunk_val = arith .ConstantOp (index , thunk .extract_element ())
898994 else :
899- thunk_val = arith .ConstantOp (thunk .dtype .build_mlir_type (), thunk ._obj . item ())
995+ thunk_val = arith .ConstantOp (thunk .dtype .build_mlir_type (), thunk .extract_element ())
900996 cmp = op (arg0 , rowidx , colidx , thunk_val )
901997 sparse_tensor .YieldOp (result = cmp )
902998 linalg .YieldOp ([res ])
@@ -977,9 +1073,7 @@ def reduce_to_scalar(op: Monoid, sp: SparseTensorBase):
9771073 mem_out = get_scalar_output_pointer (sp .dtype )
9781074 arg_pointers = [sp ._obj , mem_out ]
9791075 engine_cache [key ].invoke ('main' , * arg_pointers )
980- s = Scalar .new (sp .dtype )
981- s .set_element (mem_out .contents .value )
982- return s
1076+ return Scalar .new (sp .dtype , mem_out .contents .value )
9831077
9841078
9851079def _build_reduce_to_scalar (op : Monoid , sp : SparseTensorBase ):
0 commit comments