6161from pyop2 .utils import cached_property , get_petsc_dir
6262from pyop2 .configuration import configuration
6363from pyop2 .codegen .rep2loopy import _PreambleGen
64+ from pyop2 .datatypes import ScalarType
6465
6566
6667def vectorise (wrapper , iname , batch_size ):
@@ -75,8 +76,8 @@ def vectorise(wrapper, iname, batch_size):
7576 # create constant zero vectors
7677 wrapper = wrapper .copy (target = loopy .CVecTarget ())
7778 kernel = wrapper .root_kernel
78- zeros = loopy .TemporaryVariable ("_zeros" , shape = loopy .auto , dtype = numpy . float64 , read_only = True ,
79- initializer = numpy .array (0.0 , dtype = numpy . float64 ),
79+ zeros = loopy .TemporaryVariable ("_zeros" , shape = loopy .auto , dtype = ScalarType , read_only = True ,
80+ initializer = numpy .array (0.0 , dtype = ScalarType ),
8081 address_space = loopy .AddressSpace .GLOBAL , zero_size = batch_size )
8182 tmps = kernel .temporary_variables .copy ()
8283 tmps ["_zeros" ] = zeros
@@ -103,9 +104,12 @@ def vectorise(wrapper, iname, batch_size):
103104 wrapper = wrapper .with_root_kernel (kernel )
104105
105106 # vector data type
106- vec_types = [("double" , 8 ), ("int" , 4 )] # scalar type, bytes
107- preamble = ["typedef {0} {0}{1} __attribute__ ((vector_size ({2})));" .format (t , batch_size , batch_size * b ) for t , b in vec_types ]
108- preamble = "\n " + "\n " .join (preamble )
107+ dw_typedef = ["typedef signed char SC;" ]
108+ preamble_dw = "\n " + "\n " .join (dw_typedef )
109+ vec_types = [("double" , 8 ), ("int" , 4 ), ("SC" , 1 )] # scalar type, bytes
110+ vec_typedef = ["typedef {0} {0}{1} __attribute__ ((vector_size ({2})));" .format (t , batch_size , batch_size * b ) for t , b in vec_types ]
111+ preamble_vec = "\n " + "\n " .join (vec_typedef )
112+ preamble = preamble_dw + preamble_vec
109113
110114 wrapper = loopy .register_preamble_generators (wrapper , [_PreambleGen (preamble , idx = "01" )])
111115 return wrapper
0 commit comments