@@ -251,5 +251,58 @@ def test_power(self):
251
251
self ._test_power ((0 , 0 ))
252
252
253
253
254
+ class TestPowerAPI_Alias (unittest .TestCase ):
255
+ """
256
+ Test the alias of pow function.
257
+ ``pow(input=2, exponent=1.1)`` is equivalent to ``pow(x=2, y=1.1)``
258
+ """
259
+
260
+ def setUp (self ):
261
+ self .places = get_devices ()
262
+ self .test_cases = [
263
+ ([1.0 , 2.0 , 3.0 ], [1.1 ]), # 1D tensor
264
+ ([[1 , 2 ], [3 , 4 ]], 2 ), # 2D tensor with scalar exponent
265
+ (3.0 , [2.0 ]), # Scalar input
266
+ ]
267
+
268
+ def test_powxy (self ):
269
+ for alias_param_1 in ["x" , "input" ]:
270
+ for alias_param_2 in ["y" , "exponent" ]:
271
+ for place in self .places :
272
+ paddle .set_device (place )
273
+ paddle .disable_static (place )
274
+ for input_data , exp_data in self .test_cases :
275
+ input_tensor = paddle .to_tensor (input_data )
276
+ exp_tensor = paddle .to_tensor (exp_data )
277
+ output_alias = paddle .pow (
278
+ ** {
279
+ alias_param_1 : input_tensor ,
280
+ alias_param_2 : exp_tensor ,
281
+ }
282
+ )
283
+ output_std = paddle .pow (x = input_tensor , y = exp_tensor )
284
+ self .assertTrue (
285
+ paddle .allclose (output_alias , output_std ),
286
+ msg = f"Alias { alias_param_1 } /{ alias_param_2 } failed on { place } with input { input_data } , exp { exp_data } " ,
287
+ )
288
+
289
+ def test_xpowy (self ):
290
+ for alias_param_2 in ["y" , "exponent" ]:
291
+ for place in self .places :
292
+ paddle .set_device (place )
293
+ paddle .disable_static (place )
294
+ for input_data , exp_data in self .test_cases :
295
+ input_tensor = paddle .to_tensor (input_data )
296
+ exp_tensor = paddle .to_tensor (exp_data )
297
+ output_alias = input_tensor .pow (
298
+ ** {alias_param_2 : exp_tensor }
299
+ )
300
+ output_std = input_tensor .pow (y = exp_tensor )
301
+ self .assertTrue (
302
+ paddle .allclose (output_alias , output_std ),
303
+ msg = f"Alias { alias_param_2 } failed on { place } with input { input_data } , exp { exp_data } " ,
304
+ )
305
+
306
+
254
307
if __name__ == '__main__' :
255
308
unittest .main ()
0 commit comments