diff --git a/test/specdb/test_specdb_mps.py b/test/specdb/test_specdb_mps.py index a367f64..fc8128c 100644 --- a/test/specdb/test_specdb_mps.py +++ b/test/specdb/test_specdb_mps.py @@ -66,6 +66,35 @@ def test_all_ops_mps_half(self): ) self._run_all_ops(config=config, skip_ops=skip_ops) + def test_all_ops_mps_transposed(self): + skip_ops = self.SKIP_OPS.copy() + # argmax.default ['ArgType.Tensor torch.float32 (8, 8)', 'ArgType.DimOpt None', 'ArgType.Bool True'] + # Exception occurred: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + skip_ops += ["argmax.default", "argmin.default"] + + config = TensorConfig( + device="mps", disallow_dtypes=[torch.float64], transposed=True + ) + self._run_all_ops(config=config, skip_ops=skip_ops) + + def test_all_ops_mps_permuted(self): + skip_ops = self.SKIP_OPS.copy() + skip_ops += ["argmax.default", "argmin.default"] + + config = TensorConfig( + device="mps", disallow_dtypes=[torch.float64], permuted=True + ) + self._run_all_ops(config=config, skip_ops=skip_ops) + + def test_all_ops_mps_strided(self): + skip_ops = self.SKIP_OPS.copy() + skip_ops += ["argmax.default", "argmin.default"] + + config = TensorConfig( + device="mps", disallow_dtypes=[torch.float64], strided=True + ) + self._run_all_ops(config=config, skip_ops=skip_ops) + if __name__ == "__main__": unittest.main()