Skip to content

Commit 4636378

Browse files
Test mps ops with non-default memory layouts
ghstack-source-id: d7ed79f Pull-Request: #40
1 parent b774721 commit 4636378

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/specdb/test_specdb_mps.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,35 @@ def test_all_ops_mps_half(self):
6666
)
6767
self._run_all_ops(config=config, skip_ops=skip_ops)
6868

69+
def test_all_ops_mps_transposed(self):
70+
skip_ops = self.SKIP_OPS.copy()
71+
# argmax.default ['ArgType.Tensor torch.float32 (8, 8)', 'ArgType.DimOpt None', 'ArgType.Bool True']
72+
# Exception occurred: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
73+
skip_ops += ["argmax.default", "argmin.default"]
74+
75+
config = TensorConfig(
76+
device="mps", disallow_dtypes=[torch.float64], transposed=True
77+
)
78+
self._run_all_ops(config=config, skip_ops=skip_ops)
79+
80+
def test_all_ops_mps_permuted(self):
81+
skip_ops = self.SKIP_OPS.copy()
82+
skip_ops += ["argmax.default", "argmin.default"]
83+
84+
config = TensorConfig(
85+
device="mps", disallow_dtypes=[torch.float64], permuted=True
86+
)
87+
self._run_all_ops(config=config, skip_ops=skip_ops)
88+
89+
def test_all_ops_mps_strided(self):
90+
skip_ops = self.SKIP_OPS.copy()
91+
skip_ops += ["argmax.default", "argmin.default"]
92+
93+
config = TensorConfig(
94+
device="mps", disallow_dtypes=[torch.float64], strided=True
95+
)
96+
self._run_all_ops(config=config, skip_ops=skip_ops)
97+
6998

7099
if __name__ == "__main__":
71100
unittest.main()

0 commit comments

Comments
 (0)