@@ -26,7 +26,7 @@ def forward(self, x):
2626 return self .classifier (x )
2727
2828
29- class ArgsTest (unittest .TestCase ):
29+ class GlobalsTest (unittest .TestCase ):
3030 def testGlobalParameters (self ):
3131 m = SimpleParams ()
3232
@@ -63,10 +63,6 @@ def read_params(self):
6363 "%_params.classifier.bias = util.global.load @_params.classifier.bias" ,
6464 module_str ,
6565 )
66- self .assertIn (
67- "return %_params.classifier.weight, %_params.classifier.bias" ,
68- module_str ,
69- )
7066
7167 def testGlobalLoadFromPyLeaf (self ):
7268 m = SimpleParams ()
@@ -84,7 +80,6 @@ def read_weight(self):
8480 "%_params.classifier.weight = util.global.load @_params.classifier.weight" ,
8581 module_str ,
8682 )
87- self .assertIn ("return %_params.classifier.weight" , module_str )
8883
8984 def testGlobalStoreFromPyTree (self ):
9085 m = SimpleParams ()
@@ -100,8 +95,10 @@ def update_params(me, updates=abstractify(params)):
10095 inst = GlobalModule (context = Context ())
10196 module_str = str (CompiledModule .get_mlir_module (inst ))
10297 print (module_str )
103- self .assertIn ("util.global.store %arg0, @_params.classifier.weight" , module_str )
104- self .assertIn ("util.global.store %arg1, @_params.classifier.bias" , module_str )
98+ self .assertRegex (
99+ module_str , "util.global.store %.*, @_params.classifier.weight"
100+ )
101+ self .assertRegex (module_str , "util.global.store %.*, @_params.classifier.bias" )
105102
106103 def testGlobalStoreFromLeaf (self ):
107104 m = SimpleParams ()
@@ -115,7 +112,7 @@ def update_bias(self, new_bias=abstractify(params["classifier.bias"])):
115112 inst = GlobalModule (context = Context ())
116113 module_str = str (CompiledModule .get_mlir_module (inst ))
117114 print (module_str )
118- self .assertIn ( "util.global.store %arg0 , @_params.classifier.bias" , module_str )
115+ self .assertRegex ( module_str , "util.global.store %.* , @_params.classifier.bias" )
119116
120117 def testExportSingleGlobalTensor (self ):
121118 state_example = torch .randn (3 , 11 )
@@ -131,7 +128,6 @@ def read_state(self):
131128 print (module_str )
132129 self .assertIn ("util.global private @_state0.global" , module_str )
133130 self .assertIn ("%_state0.global = util.global.load @_state0.global" , module_str )
134- self .assertIn ("return %_state0.global" , module_str )
135131
136132 def testExportTreeGlobalTensors (self ):
137133 state_example = {
@@ -160,10 +156,6 @@ def read_state(self):
160156 self .assertIn ("%_state0.seq.0 = util.global.load @_state0.seq.0" , module_str )
161157 self .assertIn ("%_state0.seq.1 = util.global.load @_state0.seq.1" , module_str )
162158 self .assertIn ("%_state0.seq.2 = util.global.load @_state0.seq.2" , module_str )
163- self .assertIn (
164- "return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2" ,
165- module_str ,
166- )
167159
168160 def testExportGlobalScalars (self ):
169161 class ScalarState (CompiledModule ):
@@ -210,9 +202,6 @@ class DerivedState(BaseState):
210202 print (module_str )
211203 self .assertIn ("@_state_index.global {noinline} = 0 : index" , module_str )
212204 self .assertIn ("@_state_f32.global {noinline} = 0.000000e+00 : f32" , module_str )
213- self .assertIn (
214- "return %_state_index.global, %_state_f32.global : index, f32" , module_str
215- )
216205
217206 def testInheritOverrideBase (self ):
218207 class BaseState (CompiledModule ):
@@ -252,8 +241,10 @@ class DerivedModule(BaseModule):
252241 inst = DerivedModule (context = Context ())
253242 module_str = str (CompiledModule .get_mlir_module (inst ))
254243 print (module_str )
255- self .assertIn ("util.global.store %arg0, @_params.classifier.weight" , module_str )
256- self .assertIn ("util.global.store %arg1, @_params.classifier.bias" , module_str )
244+ self .assertRegex (
245+ module_str , "util.global.store %.*, @_params.classifier.weight"
246+ )
247+ self .assertRegex (module_str , "util.global.store %.*, @_params.classifier.bias" )
257248
258249 def testUpdateGlobalStateTree (self ):
259250 state_example = {
@@ -287,10 +278,10 @@ def read_state(self, updates=abstractify(state_example)):
287278 module_str ,
288279 )
289280 self .assertIn ("util.global private mutable @_state0.data" , module_str )
290- self .assertIn ( "util.global.store %arg0 , @_state0.data" , module_str )
291- self .assertIn ( "util.global.store %arg1 , @_state0.seq.0" , module_str )
292- self .assertIn ( "util.global.store %arg2 , @_state0.seq.1" , module_str )
293- self .assertIn ( "util.global.store %arg3 , @_state0.seq.2" , module_str )
281+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.data" )
282+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.seq.0" )
283+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.seq.1" )
284+ self .assertRegex ( module_str , "util.global.store %.* , @_state0.seq.2" )
294285
295286 def testTensorUpdateGlobal (self ):
296287 state_example = torch .randn (5 , 20 )
@@ -305,9 +296,9 @@ def tensor_update_state(self, update=abstractify(update_example)):
305296 inst = UpdateState (context = Context ())
306297 module_str = str (CompiledModule .get_mlir_module (inst ))
307298 print (module_str )
308- self .assertIn (
309- "flow.tensor.update %arg0, %_state0.global[%c0, %c0] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>" ,
299+ self .assertRegex (
310300 module_str ,
301+ "flow.tensor.update %.*, %_state0.global\\ [%c0, %c0\\ ] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>" ,
311302 )
312303
313304 def testTensorUpdateGlobalReturnNone (self ):
@@ -325,10 +316,7 @@ def tensor_update_state(self, update=abstractify(update_example)):
325316 inst = UpdateState (context = Context ())
326317 module_str = str (CompiledModule .get_mlir_module (inst ))
327318 print (module_str )
328- self .assertIn (
329- "flow.tensor.update %arg0, %_state0.global[%c4, %c0, %c0] : tensor<1x1x4xf32> -> %_state0.global as tensor<5x20x4xf32>" ,
330- module_str ,
331- )
319+ self .assertIn ("flow.tensor.update" , module_str )
332320
333321 def testExternalGlobalParametersDefaults (self ):
334322 m = SimpleParams ()
0 commit comments