@@ -989,27 +989,22 @@ def apply(self, **kwargs):
989989 # Get items expected to be specialized
990990 specialize = as_tuple (kwargs .pop ('specialize' , []))
991991
992- # In the case of specialization, arguments must be processed before
993- # the operator can be compiled
994992 if specialize :
995993 # FIXME: Cannot cope with things like sizes/strides yet since it only
996994 # looks at the parameters
997995
998996 # Build the arguments list for specialization
999- with self ._profiler .timer_on ('specialized-arguments-preprocess ' ):
997+ with self ._profiler .timer_on ('specialization ' ):
1000998 args = self .arguments (** kwargs )
1001- with switch_log_level (comm = args .comm ):
1002- self ._emit_args_profiling ('specialized-arguments-preprocess' )
999+ # Uses parameters here since Specializer needs {symbol: sympy value}
1000+ specialized_values = {p : sympify (args [p .name ])
1001+ for p in self .parameters if p .name in specialize }
10031002
1004- # Uses parameters here since Specializer needs {symbol: sympy value} mapper
1005- specialized_values = {p : sympify (args [p .name ])
1006- for p in self .parameters if p .name in specialize }
1003+ op = Specializer (specialized_values ).visit (self )
10071004
1008- op = Specializer (specialized_values ).visit (self )
1005+ with switch_log_level (comm = args .comm ):
1006+ self ._emit_args_profiling ('specialization' )
10091007
1010- # TODO: Does this cause problems for profilers?
1011- # FIXME: Need some way to inspect this Operator for testing
1012- # FIXME: Perhaps this should use some separate method
10131008 unspecialized_kwargs = {k : v for k , v in kwargs .items ()
10141009 if k not in specialize }
10151010
@@ -1025,9 +1020,7 @@ def apply(self, **kwargs):
10251020 with switch_log_level (comm = args .comm ):
10261021 self ._emit_args_profiling ('arguments-preprocess' )
10271022
1028- args_string = ", " .join ([f"{ p .name } ={ args [p .name ]} "
1029- for p in self .parameters if p .is_Symbol ])
1030- debug (f"Invoking `{ self .name } ` with scalar arguments: { args_string } " )
1023+ self ._emit_arguments (args )
10311024
10321025 # Invoke kernel function with args
10331026 arg_values = [args [p .name ] for p in self .parameters ]
@@ -1064,6 +1057,28 @@ def _emit_args_profiling(self, tag=''):
10641057 tagstr = ' ' .join (tag .split ('-' ))
10651058 debug (f"Operator `{ self .name } ` { tagstr } : { elapsed :.2f} s" )
10661059
def _emit_arguments(self, args):
    """
    Log, at debug level, the scalar arguments used to invoke this Operator.

    A header line is emitted first, followed by one line per MPI rank when
    ``args.comm`` is a valid communicator, or by a single line otherwise.

    Parameters
    ----------
    args : mapping-like
        The processed runtime arguments. It must expose a ``comm``
        attribute and be indexable by parameter name.

    Notes
    -----
    With a valid communicator this method performs a collective operation,
    so it must be reached by every rank in ``args.comm``.
    """
    comm = args.comm
    has_mpi = comm is not MPI.COMM_NULL

    # Only the symbolic (scalar) parameters are reported
    scalar_args = ", ".join([f"{p.name}={args[p.name]}"
                             for p in self.parameters
                             if p.is_Symbol])

    # Query the rank once and reuse it for both the prefix and the printing
    myrank = comm.Get_rank() if has_mpi else None
    rank = f"[rank{myrank}] " if has_mpi else ""

    msg = f"* {rank}{scalar_args}"

    # NOTE(review): presumably `switch_log_level` limits the header to a
    # single rank -- confirm against its definition
    with switch_log_level(comm=comm):
        debug(f"Scalar arguments used to invoke `{self.name}`")

    if not has_mpi:
        debug(msg)
        return

    # With MPI enabled, we add one entry per rank. Only rank 0 prints, so a
    # rooted gather is enough: the other ranks do not need (and previously
    # discarded) the full list of messages that `allgather` delivered
    allmsg = comm.gather(msg, root=0)
    if myrank == 0:
        for m in allmsg:
            debug(m)
1081+
10671082 def _emit_build_profiling (self ):
10681083 if not is_log_enabled_for ('PERF' ):
10691084 return
0 commit comments