-
Notifications
You must be signed in to change notification settings - Fork 75
Closed
Closed
Copy link
Description
Kernels from #5253 have verification issues on large shapes, for example with:
BATCHED_MM_X_VALS = [
(256, 16, 7168, 2048),
]
Even if the inputs are modified to:
A_q[:] = 0
A_q[:, 0, 0] = 1
B_q[:] = 0
B_q[:, 0, 0] = 0
, which should give all-zero outputs, we get:
output.max(axis=(1, 2))
array([0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 3.140625 , 3.3125 , 2.890625 ,
0. , 3.3125 , 0. , 3.25 , 3.203125 ,
3.09375 , 3.40625 , 3.359375 , 3.421875 , 3.953125 ,
4.09375 , 3.859375 , 3.515625 , 3.390625 , 3.09375 ,
0.20898438, 0.2578125 , 0.21582031, 0.20507812, 0.20898438,
0.20898438, 0.27734375, 0.19335938, 0.24804688, 0. ,
0.22753906, 0.21289062, 0.20898438, 0. , 0.23730469,
0.23046875, 0.20996094, 0.3046875 , 0.22753906, 0.20800781,
0.18945312, 0.21386719, 0.22460938, 0.2734375 , 0.25390625,
0.24023438, 0.22753906, 0.19921875, 0.23632812, 0.21582031,
0.25976562, 0.23535156, 0.20214844, 0.21191406, 0.25195312,
0.23632812, 0.25585938, 0.21191406, 0. , 0. ,
0.22851562, 0.26953125, 0.24902344, 0.20214844, 0.22558594,
0.25195312, 0.19628906, 0.19824219, 0.23828125, 0.3046875 ,
0.23535156, 0.22949219, 0.19824219, 0.23828125, 0. ,
0.22167969, 0. , 0.19433594, 0.21777344, 0.25976562,
0.24902344, 0.19628906, 0.22753906, 0. , 0.21289062,
0.24414062, 0.21386719, 0.21191406, 0.25195312, 0.20996094,
0.19824219, 0.20703125, 0.22070312, 0.19042969, 0.19335938,
0.20996094, 0.20800781, 0.27929688, 0.20703125, 0.24414062,
0.23144531, 0.21972656, 0.20019531, 0.23535156, 0.23144531,
0. , 0.22167969, 0.35742188, 0.234375 , 0.23242188,
0.25390625], dtype=float32)