From 436a043278c3def645d9d7e965edca4f191f171f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 18 Aug 2025 13:27:30 -0400 Subject: [PATCH] Update [ghstack-poisoned] --- test/specdb/test_specdb_mps.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/specdb/test_specdb_mps.py b/test/specdb/test_specdb_mps.py index ea915a2..b81d97b 100644 --- a/test/specdb/test_specdb_mps.py +++ b/test/specdb/test_specdb_mps.py @@ -66,6 +66,35 @@ def test_all_ops_mps_half(self): ) self._run_all_ops(config=config, skip_ops=skip_ops) + def test_all_ops_mps_transposed(self): + skip_ops = self.SKIP_OPS.copy() + # argmax.default ['ArgType.Tensor torch.float32 (8, 8)', 'ArgType.DimOpt None', 'ArgType.Bool True'] + # Exception occurred: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + skip_ops += ["argmax.default", "argmin.default"] + + config = TensorConfig( + device="mps", disallow_dtypes=[torch.float64], transposed=True + ) + self._run_all_ops(config=config, skip_ops=skip_ops) + + def test_all_ops_mps_permuted(self): + skip_ops = self.SKIP_OPS.copy() + skip_ops += ["argmax.default", "argmin.default"] + + config = TensorConfig( + device="mps", disallow_dtypes=[torch.float64], permuted=True + ) + self._run_all_ops(config=config, skip_ops=skip_ops) + + def test_all_ops_mps_strided(self): + skip_ops = self.SKIP_OPS.copy() + skip_ops += ["argmax.default", "argmin.default"] + + config = TensorConfig( + device="mps", disallow_dtypes=[torch.float64], strided=True + ) + self._run_all_ops(config=config, skip_ops=skip_ops) + if __name__ == "__main__": unittest.main()