@@ -145,7 +145,9 @@ def write_nonlead(self):
145145 return f'{ self ._nonlead } '
146146
147147 def write (self , context : Context ):
# Returns a string expression built from self._nonlead, self._block and (in the
# non-SIMD multi-block case) self._stride plus the lexic's thread_idx_x.
# This hunk puts a new leading branch in front of the existing ones: when the
# lexic's simd_mode is truthy the result is '(<nonlead> * <block>)' and no
# thread-index term is emitted, regardless of the value of self._block.
# NOTE(review): only the branches visible in this hunk are described; any
# trailing else-branch lies outside the shown lines.
148- if self ._block > 1 :
148+ if context .get_vm ().get_lexic ().simd_mode :
149+ return f'({ self ._nonlead } * { self ._block } )'
150+ elif self ._block > 1 :
149151 return f'(({ context .get_vm ().get_lexic ().thread_idx_x } / { self ._stride } ) % { self ._block } ) + { self ._nonlead } * { self ._block } '
150152 elif self ._block == 1 :
151153 return f'{ self ._nonlead } '
@@ -443,11 +445,14 @@ def encode_values(self, pos, runIdx, writer, context: Context, variable, index:
443445 return wrote
444446
445447 def load_linear (self , writer , context : Context , variable , index ):
# Emits, through writer, one statement that declares `variable` and initialises it
# for the given linear index. After this change the SIMD check comes first and the
# previous Register / non-Register logic moves unchanged into the else-branch.
446- if self . stype == SymbolType . Register :
447- access = f'{ self .name } [ { index // 32 } ]' # TODO
# NOTE(review): the SIMD statement below is '<simd type> <variable>(<index>);' -- the
# emitted text never mentions self.name, unlike both non-SIMD accesses. Confirm that
# constructing from the bare index is intended rather than from the symbol's storage.
# NOTE(review): the second argument to simd(...) is the literal 16; presumably the
# vector width -- confirm, and consider sourcing it from the lexic.
448+ if context . get_vm (). get_lexic (). simd_mode :
449+ writer ( f'{ context . get_vm (). get_lexic (). simd ( self .get_fptype (), 16 ) } { variable } ( { index } );' )
448450 else :
449- access = f'{ self .name } [{ index } + threadIdx.x]'
450- writer (f'{ self .get_fptype ()} { variable } = { access } ;' )
451+ if self .stype == SymbolType .Register :
# Register symbols are indexed by floor(index / 32); the divisor is hard-coded
# (see the existing TODO).
452+ access = f'{ self .name } [{ index // 32 } ]' # TODO
453+ else :
# NOTE(review): the literal text 'threadIdx.x' is emitted here, whereas other
# code in this file asks the lexic for thread_idx_x -- confirm which is wanted.
454+ access = f'{ self .name } [{ index } + threadIdx.x]'
455+ writer (f'{ self .get_fptype ()} { variable } = { access } ;' )
451456
452457 def load (self , writer , context : Context , variable , index : List [Union [str , int , Immediate , Variable , LeadIndex ]], nontemp ):
# NOTE(review): this method is shown as three separate hunks; lines between the
# hunks are not visible, so the notes below cover only the changed lines.
453458 if self .stype == SymbolType .Data or (not self .obj .is_dense () and not isinstance (self .obj .spp , BoundingBoxSPP )):
@@ -473,7 +478,9 @@ def load(self, writer, context: Context, variable, index: List[Union[str, int, I
473478 if self .stype == SymbolType .Register or self .stype == SymbolType .Scratch :
474479 assert len (self .lead_dims ) == 1
475480 idx = index [self .lead_dims [0 ]]
476- if not idx .is_thread_dependent ():
# The isinstance test is evaluated first, so plain numbers (float, int, np.int32)
# never reach idx.is_thread_dependent(); they are then wrapped in an Immediate.
# NOTE(review): float is accepted but wrapped with Datatype.I32 -- confirm that a
# float lead index is meant to be treated as a 32-bit integer.
481+ if isinstance (idx , (float , int , np .int32 )) or not idx .is_thread_dependent ():
482+ if isinstance (idx , (float , int , np .int32 )):
483+ idx = Immediate (idx , Datatype .I32 )
477484 # doesn't work
478485 if isinstance (idx , Variable ):
479486 writevar = idx .write_nonlead ()
@@ -490,7 +497,9 @@ def load(self, writer, context: Context, variable, index: List[Union[str, int, I
490497 access = pre_access
491498 else :
492499 access = pre_access
493- if self .stype == SymbolType .Global :
# The SIMD check now precedes the Global check: in simd_mode a single statement
# '<simd type> <variable>(<access>);' is emitted and neither glb_load nor the
# nontemp argument is used on this path. The width argument is the literal 16.
500+ if context .get_vm ().get_lexic ().simd_mode :
501+ writer (f'{ context .get_vm ().get_lexic ().simd (self .get_fptype (), 16 )} { variable } ({ access } );' )
502+ elif self .stype == SymbolType .Global :
494503 writer (f'{ self .get_fptype ()} { variable } ;' )
495504 writer (context .get_vm ().get_lexic ().glb_load (variable , access , nontemp ))
496505 else :
@@ -502,19 +511,26 @@ def store(self, writer, context, variable, index: List[Union[str, int, Immediate
502511
503512 access = self .access (context , index )
504513
# After this change the method first branches on the lexic's simd_mode; the previous
# store logic (build `assign`, then optionally guard it by thread index for
# Register/Scratch symbols) moves unchanged into the outer else-branch.
# NOTE(review): in the SIMD non-Global branch the emitted statement is
# '<variable> = <access>;', i.e. it assigns TO the variable. The non-SIMD path emits
# '<access> = <variable>;'. For a store this direction looks reversed -- confirm.
# NOTE(review): the SIMD path uses neither nontemp / glb_store nor the
# thread_idx_x guard that the non-SIMD Register/Scratch path applies.
505- if self .stype == SymbolType .Global :
506- assign = context .get_vm ().get_lexic ().glb_store (access , variable , nontemp )
514+ if context .get_vm ().get_lexic ().simd_mode :
515+ if self .stype == SymbolType .Global :
516+ writer (f'{ variable } .copy_to({ access } );' )
517+ else :
518+ writer (f'{ variable } = { access } ;' )
507519 else :
508- assign = f'{ access } = { variable } ;'
509- if self .stype == SymbolType .Register or self .stype == SymbolType .Scratch :
510- assert len (self .lead_dims ) == 1
511- if isinstance (index [self .lead_dims [0 ]], LeadIndex ):
512- writer (assign )
520+ if self .stype == SymbolType .Global :
521+ assign = context .get_vm ().get_lexic ().glb_store (access , variable , nontemp )
513522 else :
514- with writer .If (f'{ context .get_vm ().get_lexic ().thread_idx_x } == { index [self .lead_dims [0 ]]} ' ):
523+ assign = f'{ access } = { variable } ;'
524+
# Register/Scratch symbols: a LeadIndex lead index writes unconditionally; any other
# lead index wraps the write in an If on thread_idx_x == that index.
525+ if self .stype == SymbolType .Register or self .stype == SymbolType .Scratch :
526+ assert len (self .lead_dims ) == 1
527+ if isinstance (index [self .lead_dims [0 ]], LeadIndex ):
515528 writer (assign )
516- else :
517- writer (assign )
529+ else :
530+ with writer .If (f'{ context .get_vm ().get_lexic ().thread_idx_x } == { index [self .lead_dims [0 ]]} ' ):
531+ writer (assign )
532+ else :
533+ writer (assign )
518534
519535 def add_user (self , user ):
# Appends `user` to self._users; no de-duplication is performed in the visible line.
520536 self ._users .append (user )
0 commit comments