55
66import numpy
77from loopy .types import OpaqueType
8+ from pyop2 .global_kernel import (GlobalKernelArg , DatKernelArg , MixedDatKernelArg ,
9+ MatKernelArg , MixedMatKernelArg , PermutedMapKernelArg )
810from pyop2 .codegen .representation import (Accumulate , Argument , Comparison ,
911 DummyInstruction , Extent , FixedIndex ,
1012 FunctionCall , Index , Indexed ,
1618 When , Zero )
1719from pyop2 .datatypes import IntType
1820from pyop2 .op2 import (ALL , INC , MAX , MIN , ON_BOTTOM , ON_INTERIOR_FACETS ,
19- ON_TOP , READ , RW , WRITE , Subset , PermutedMap )
21+ ON_TOP , READ , RW , WRITE )
2022from pyop2 .utils import cached_property
2123
2224
@@ -32,18 +34,22 @@ class Map(object):
3234 "variable" , "unroll" , "layer_bounds" ,
3335 "prefetch" , "_pmap_count" )
3436
35- def __init__ (self , map_ , interior_horizontal , layer_bounds ,
36- offset = None , unroll = False ):
37- self .variable = map_ .iterset ._extruded and not map_ .iterset .constant_layers
37+ def __init__ (self , interior_horizontal , layer_bounds ,
38+ arity , dtype ,
39+ offset = None , unroll = False ,
40+ extruded = False , constant_layers = False ):
41+ self .variable = extruded and not constant_layers
3842 self .unroll = unroll
3943 self .layer_bounds = layer_bounds
4044 self .interior_horizontal = interior_horizontal
4145 self .prefetch = {}
42- offset = map_ . offset
43- shape = (None , ) + map_ . shape [ 1 :]
44- values = Argument (shape , dtype = map_ . dtype , pfx = "map" )
46+
47+ shape = (None , arity )
48+ values = Argument (shape , dtype = dtype , pfx = "map" )
4549 if offset is not None :
46- if len (set (map_ .offset )) == 1 :
50+ assert type (offset ) == tuple
51+ offset = numpy .array (offset , dtype = numpy .int32 )
52+ if len (set (offset )) == 1 :
4753 offset = Literal (offset [0 ], casting = True )
4854 else :
4955 offset = NamedLiteral (offset , parent = values , suffix = "offset" )
@@ -616,15 +622,18 @@ def emit_unpack_instruction(self, *,
616622
617623class WrapperBuilder (object ):
618624
619- def __init__ (self , * , kernel , iterset , iteration_region = None , single_cell = False ,
625+ def __init__ (self , * , kernel , subset , extruded , constant_layers , iteration_region = None , single_cell = False ,
620626 pass_layer_to_kernel = False , forward_arg_types = ()):
621627 self .kernel = kernel
628+ self .local_knl_args = iter (kernel .arguments )
622629 self .arguments = []
623630 self .argument_accesses = []
624631 self .packed_args = []
625632 self .indices = []
626633 self .maps = OrderedDict ()
627- self .iterset = iterset
634+ self .subset = subset
635+ self .extruded = extruded
636+ self .constant_layers = constant_layers
628637 if iteration_region is None :
629638 self .iteration_region = ALL
630639 else :
@@ -637,18 +646,6 @@ def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False,
637646 def requires_zeroed_output_arguments (self ):
638647 return self .kernel .requires_zeroed_output_arguments
639648
640- @property
641- def subset (self ):
642- return isinstance (self .iterset , Subset )
643-
644- @property
645- def extruded (self ):
646- return self .iterset ._extruded
647-
648- @property
649- def constant_layers (self ):
650- return self .extruded and self .iterset .constant_layers
651-
652649 @cached_property
653650 def loop_extents (self ):
654651 return (Argument ((), IntType , name = "start" ),
@@ -753,94 +750,98 @@ def loop_indices(self):
753750 return (self .loop_index , None , self ._loop_index )
754751
755752 def add_argument (self , arg ):
753+ local_arg = next (self .local_knl_args )
754+ access = local_arg .access
755+ dtype = local_arg .dtype
756756 interior_horizontal = self .iteration_region == ON_INTERIOR_FACETS
757- if arg . _is_dat :
758- if arg . _is_mixed :
759- packs = []
760- for a in arg :
761- shape = a . data . shape [ 1 :]
762- if shape == ():
763- shape = ( 1 , )
764- shape = ( None , * shape )
765- argument = Argument ( shape , a . data . dtype , pfx = "mdat" )
766- packs . append ( a . data . pack ( argument , arg . access , self . map_ ( a . map , unroll = a . unroll_map ),
767- interior_horizontal = interior_horizontal ,
768- init_with_zero = self . requires_zeroed_output_arguments ) )
769- self . arguments . append ( argument )
770- pack = MixedDatPack ( packs , arg . access , arg . dtype , interior_horizontal = interior_horizontal )
771- self . packed_args . append ( pack )
772- self .argument_accesses . append (arg .access )
757+
758+ if isinstance ( arg , GlobalKernelArg ) :
759+ argument = Argument ( arg . dim , dtype , pfx = "glob" )
760+
761+ pack = GlobalPack ( argument , access ,
762+ init_with_zero = self . requires_zeroed_output_arguments )
763+ self . arguments . append ( argument )
764+ elif isinstance ( arg , DatKernelArg ):
765+ if arg . dim == ():
766+ shape = ( None , 1 )
767+ else :
768+ shape = ( None , * arg . dim )
769+ argument = Argument ( shape , dtype , pfx = "dat" )
770+
771+ if arg . is_indirect :
772+ map_ = self ._add_map (arg .map_ )
773773 else :
774- if arg ._is_dat_view :
775- view_index = arg .data .index
776- data = arg .data ._parent
774+ map_ = None
775+ pack = arg .pack (argument , access , map_ = map_ ,
776+ interior_horizontal = interior_horizontal ,
777+ view_index = arg .index ,
778+ init_with_zero = self .requires_zeroed_output_arguments )
779+ self .arguments .append (argument )
780+ elif isinstance (arg , MixedDatKernelArg ):
781+ packs = []
782+ for a in arg :
783+ if a .dim == ():
784+ shape = (None , 1 )
785+ else :
786+ shape = (None , * a .dim )
787+ argument = Argument (shape , dtype , pfx = "mdat" )
788+
789+ if a .is_indirect :
790+ map_ = self ._add_map (a .map_ )
777791 else :
778- view_index = None
779- data = arg .data
780- shape = data .shape [1 :]
781- if shape == ():
782- shape = (1 ,)
783- shape = (None , * shape )
784- argument = Argument (shape ,
785- arg .data .dtype ,
786- pfx = "dat" )
787- pack = arg .data .pack (argument , arg .access , self .map_ (arg .map , unroll = arg .unroll_map ),
788- interior_horizontal = interior_horizontal ,
789- view_index = view_index ,
790- init_with_zero = self .requires_zeroed_output_arguments )
792+ map_ = None
793+
794+ packs .append (arg .pack (argument , access , map_ ,
795+ interior_horizontal = interior_horizontal ,
796+ init_with_zero = self .requires_zeroed_output_arguments ))
791797 self .arguments .append (argument )
792- self .packed_args .append (pack )
793- self .argument_accesses .append (arg .access )
794- elif arg ._is_global :
795- argument = Argument (arg .data .dim ,
796- arg .data .dtype ,
797- pfx = "glob" )
798- pack = GlobalPack (argument , arg .access ,
799- init_with_zero = self .requires_zeroed_output_arguments )
798+ pack = MixedDatPack (packs , access , dtype ,
799+ interior_horizontal = interior_horizontal )
800+ elif isinstance (arg , MatKernelArg ):
801+ argument = Argument ((), PetscMat (), pfx = "mat" )
802+ maps = tuple (self ._add_map (m , arg .unroll )
803+ for m in arg .maps )
804+ pack = arg .pack (argument , access , maps ,
805+ arg .dims , dtype ,
806+ interior_horizontal = interior_horizontal )
800807 self .arguments .append (argument )
801- self .packed_args .append (pack )
802- self .argument_accesses .append (arg .access )
803- elif arg ._is_mat :
804- if arg ._is_mixed :
805- packs = []
806- for a in arg :
807- argument = Argument ((), PetscMat (), pfx = "mat" )
808- map_ = tuple (self .map_ (m , unroll = arg .unroll_map ) for m in a .map )
809- packs .append (arg .data .pack (argument , a .access , map_ ,
810- a .data .dims , a .data .dtype ,
811- interior_horizontal = interior_horizontal ))
812- self .arguments .append (argument )
813- pack = MixedMatPack (packs , arg .access , arg .dtype ,
814- arg .data .sparsity .shape )
815- self .packed_args .append (pack )
816- self .argument_accesses .append (arg .access )
817- else :
808+ elif isinstance (arg , MixedMatKernelArg ):
809+ packs = []
810+ for a in arg :
818811 argument = Argument ((), PetscMat (), pfx = "mat" )
819- map_ = tuple (self .map_ (m , unroll = arg .unroll_map ) for m in arg .map )
820- pack = arg .data .pack (argument , arg .access , map_ ,
821- arg .data .dims , arg .data .dtype ,
822- interior_horizontal = interior_horizontal )
812+ maps = tuple (self ._add_map (m , a .unroll )
813+ for m in a .maps )
814+
815+ packs .append (arg .pack (argument , access , maps ,
816+ a .dims , dtype ,
817+ interior_horizontal = interior_horizontal ))
823818 self .arguments .append (argument )
824- self . packed_args . append ( pack )
825- self . argument_accesses . append ( arg .access )
819+ pack = MixedMatPack ( packs , access , dtype ,
820+ arg .shape )
826821 else :
827822 raise ValueError ("Unhandled argument type" )
828823
829- def map_ (self , map_ , unroll = False ):
824+ self .packed_args .append (pack )
825+ self .argument_accesses .append (access )
826+
827+ def _add_map (self , map_ , unroll = False ):
830828 if map_ is None :
831829 return None
832830 interior_horizontal = self .iteration_region == ON_INTERIOR_FACETS
833831 key = map_
834832 try :
835833 return self .maps [key ]
836834 except KeyError :
837- if isinstance (map_ , PermutedMap ):
838- imap = self .map_ (map_ .map_ , unroll = unroll )
839- map_ = PMap (imap , map_ .permutation )
835+ if isinstance (map_ , PermutedMapKernelArg ):
836+ imap = self ._add_map (map_ .base_map , unroll )
837+ map_ = PMap (imap , numpy . asarray ( map_ .permutation , dtype = IntType ) )
840838 else :
841- map_ = Map (map_ , interior_horizontal ,
839+ map_ = Map (interior_horizontal ,
842840 (self .bottom_layer , self .top_layer ),
843- unroll = unroll )
841+ arity = map_ .arity , offset = map_ .offset , dtype = IntType ,
842+ unroll = unroll ,
843+ extruded = self .extruded ,
844+ constant_layers = self .constant_layers )
844845 self .maps [key ] = map_
845846 return map_
846847
0 commit comments