@@ -28,9 +28,8 @@
  28 |   28 |      TuningProcessPool,
  29 |   29 |  )
  30 |   30 |  from torch._inductor.graph import GraphLowering
  31 |      | -from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout, InputBuffer
     |   31 | +from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
  32 |   32 |  from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
  33 |      | -from torch._inductor.kernel_inputs import MMKernelInputs
  34 |   33 |  from torch._inductor.select_algorithm import (
  35 |   34 |      add_feedback_saver,
  36 |   35 |      AlgorithmSelectorCache,
@@ -76,7 +75,7 @@
  76 |   75 |  )
  77 |   76 |
  78 |   77 |
  79 |      | -torch.backends.cuda.matmul.allow_tf32 = True
     |   78 | +torch.set_float32_matmul_precision("high")
  80 |   79 |  if HAS_CUDA_AND_TRITON:
  81 |   80 |      torch.cuda.memory._set_allocator_settings("expandable_segments:False")
  82 |   81 |
@@ -2077,39 +2076,6 @@ def f(x, y):
2077 | 2076 |          global_stats.report()
2078 | 2077 |          self.assertEqual(global_stats.autotune_remote, Stats(2, 3, 2))
2079 | 2078 |
2080 |      | -    def test_get_mm_configs_float32_precision_ieee(self):
2081 |      | -        """Test that configs returned from choices.get_mm_configs use float32_precision == ieee."""
2082 |      | -        from torch._inductor.choices import InductorChoices
2083 |      | -        from torch._inductor.graph import GraphLowering
2084 |      | -        from torch._inductor.ir import FlexibleLayout
2085 |      | -        from torch.fx.experimental.proxy_tensor import make_fx
2086 |      | -
2087 |      | -        # Create a simple graph to get proper context
2088 |      | -        gm = make_fx(lambda: torch.zeros(2, 3))()
2089 |      | -        graph = GraphLowering(gm)
2090 |      | -
2091 |      | -        with V.set_graph_handler(graph):
2092 |      | -            device = torch.device(f"{GPU_TYPE}:0")
2093 |      | -            mat1 = InputBuffer(
2094 |      | -                name="mat1",
2095 |      | -                layout=FixedLayout(device, torch.float32, [64, 128], [128, 1]),
2096 |      | -            )
2097 |      | -            mat2 = InputBuffer(
2098 |      | -                name="mat2",
2099 |      | -                layout=FixedLayout(device, torch.float32, [128, 64], [64, 1]),
2100 |      | -            )
2101 |      | -            kernel_inputs = MMKernelInputs([mat1, mat2])
2102 |      | -            output_layout = FlexibleLayout(device, torch.float32, [64, 64])
2103 |      | -
2104 |      | -            choices = InductorChoices()
2105 |      | -            configs = list(
2106 |      | -                choices.get_mm_configs(kernel_inputs, output_layout, "mm", "mm")
2107 |      | -            )
2108 |      | -
2109 |      | -            for cfg in configs:
2110 |      | -                self.assertIn("ALLOW_TF32", cfg)
2111 |      | -                self.assertEqual(cfg["ALLOW_TF32"], True)
2112 |      | -
2113 | 2079 |
2114 | 2080 |  class _TestTritonTemplateCaller(TritonTemplateCaller):
2115 | 2081 |      def __init__(self, bmreq: _TestBenchmarkRequest):
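For context on the setup change above: the removed flag and its replacement are two interfaces to the same float32-matmul precision policy, the old one CUDA-specific and the new one device-agnostic. A minimal sketch of both (assuming a CUDA-enabled PyTorch build; all three calls are public PyTorch APIs, and the exact coupling between the two toggles is worth verifying on your version):

```python
import torch

# Legacy, CUDA-specific switch: allow float32 matmuls to run on TF32
# tensor cores (reduced mantissa precision, large speedups on Ampere+).
torch.backends.cuda.matmul.allow_tf32 = True

# Device-agnostic replacement used by this diff's setup code:
#   "highest" -> strict IEEE FP32;
#   "high"    -> TF32-class math allowed;
#   "medium"  -> even lower (bfloat16-class) precision allowed.
torch.set_float32_matmul_precision("high")

# The current global setting can be read back when debugging tests:
print(torch.get_float32_matmul_precision())  # expected: "high"
```

Presumably the motivation here is that the newer API expresses one precision policy for all backends rather than a CUDA-only backend flag.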