@@ -40,6 +40,7 @@ def is_interpreter():
4040 ("device_print_negative" , "int32" ),
4141 ("device_print_uint" , "uint32" ),
4242 ("device_print_2d_tensor" , "int32" ),
43+ ("device_print" , "bool" ),
4344 ])
4445def test_print (func_type : str , data_type : str , device : str ):
4546 if device == "xpu" and data_type == "float64" and not tr .driver .active .get_current_target ().arch ['has_fp64' ]:
@@ -66,7 +67,7 @@ def test_print(func_type: str, data_type: str, device: str):
6667 # Format is
6768 # pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
6869 expected_lines = Counter ()
69- if func_type in ("print" , "device_print" , "device_print_uint" ):
70+ if func_type in ("print" , "device_print" , "device_print_uint" ) and data_type != "bool" :
7071 for i in range (N ):
7172 offset = (1 << 31 ) if data_type == "uint32" else 0
7273 line = f"pid (0, 0, 0) idx ({ i :3} ) x: { i + offset } "
@@ -115,6 +116,10 @@ def test_print(func_type: str, data_type: str, device: str):
115116 for x in range (x_dim ):
116117 for y in range (y_dim ):
117118 expected_lines [f"pid (0, 0, 0) idx ({ x } , { y :2} ): { (x * y_dim + y )} " ] = 1
119+ elif data_type == "bool" :
120+ expected_lines ["pid (0, 0, 0) idx ( 0) x: 0" ] = 1
121+ for i in range (1 , N ):
122+ expected_lines [f"pid (0, 0, 0) idx ({ i :3} ) x: 1" ] = 1
118123
119124 actual_lines = Counter ()
120125 for line in outs :
0 commit comments