29 | 29 | from torch._dynamo.testing import normalize_gm |
30 | 30 | from torch._dynamo.utils import counters |
31 | 31 | from torch._inductor import config as inductor_config |
| 32 | +from torch._inductor.cpp_builder import is_msvc_cl |
32 | 33 | from torch._inductor.test_case import run_tests, TestCase |
33 | 34 | from torch.nn.attention.flex_attention import flex_attention |
34 | 35 | from torch.nn.parallel import DistributedDataParallel as DDP |
40 | 41 | from torch.testing._internal.common_utils import ( |
41 | 42 |     instantiate_parametrized_tests, |
42 | 43 |     IS_S390X, |
| 44 | +    IS_WINDOWS, |
43 | 45 |     parametrize, |
44 | 46 |     scoped_load_inline, |
45 | 47 |     skipIfWindows, |
@@ -193,6 +195,18 @@ def model(i): |
193 | 195 |         for _ in range(3): |
194 | 196 |             self.run_as_subprocess(script) |
195 | 197 |
| 198 | +    def gen_cache_miss_log_prefix(self): |
| 199 | +        if IS_WINDOWS: |
| 200 | +            if is_msvc_cl(): |
| 201 | +                return "Cache miss due to new autograd node: struct " |
| 202 | +            else: |
| 203 | +                self.fail( |
| 204 | +                    "Compilers other than MSVC have not yet been verified on Windows." |
| 205 | +                ) |
| 206 | +                return ""  # unreachable: self.fail() raises, but keeps the return type consistent |
| 207 | +        else: |
| 208 | +            return "Cache miss due to new autograd node: " |
| 209 | + |
196 | 210 |     def test_reset(self): |
197 | 211 |         compiled_autograd.compiled_autograd_enabled = True |
198 | 212 |         torch._C._dynamo.compiled_autograd.set_autograd_compiler(lambda: None, True) |
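Note on the new helper: on Windows the cache-miss log line carries a leading "struct " in the node name, most likely because MSVC's `std::type_info::name()` returns a human-readable name that keeps the class-key ("struct "/"class "), whereas the demangled GCC/Clang names do not; `is_msvc_cl` (imported above from `torch._inductor.cpp_builder`) detects that toolchain. A minimal standalone sketch of the intended behavior; `expected_prefix` below is hypothetical and for illustration only:

```python
import sys

# Hypothetical mirror of gen_cache_miss_log_prefix(), assuming an
# MSVC-built PyTorch on Windows; not part of the diff above.
def expected_prefix() -> str:
    if sys.platform == "win32":
        # MSVC's type_info::name() keeps the class-key, so node names
        # surface as e.g. "struct torch::autograd::GraphRoot" in the logs.
        return "Cache miss due to new autograd node: struct "
    return "Cache miss due to new autograd node: "
```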
@@ -3146,7 +3160,7 @@ def test_logs(self): |
3146 | 3160 |         self.assertEqual(counters["compiled_autograd"]["compiles"], 1) |
3147 | 3161 |         assert "torch::autograd::AccumulateGrad (NodeCall" in logs.getvalue() |
3148 | 3162 |         assert ( |
3149 | | -            "Cache miss due to new autograd node: torch::autograd::GraphRoot" |
| 3163 | +            self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot" |
3150 | 3164 |             not in logs.getvalue() |
3151 | 3165 |         ) |
3152 | 3166 |
@@ -3353,7 +3367,6 @@ def fn(x, obj): |
3353 | 3367 |             sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) |
3354 | 3368 |         ) |
3355 | 3369 |
3356 | | -    @skipIfWindows(msg="AssertionError: Scalars are not equal!") |
3357 | 3370 |     def test_verbose_logs_cpp(self): |
3358 | 3371 |         torch._logging.set_logs(compiled_autograd_verbose=True) |
3359 | 3372 |
@@ -3381,8 +3394,9 @@ def fn(): |
3381 | 3394 |         self.check_output_and_recompiles(fn) |
3382 | 3395 |
3383 | 3396 |         patterns1 = [ |
3384 | | -            r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), " |
3385 | | -            r"previous key sizes=\[\]\n", |
| 3397 | +            r".*" |
| 3398 | +            + self.gen_cache_miss_log_prefix() |
| 3399 | +            + r"torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), previous key sizes=\[\]\n", |
3386 | 3400 |         ] |
3387 | 3401 |
3388 | 3402 |         all_logs = logs.getvalue() |
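For reference, a self-contained sketch of how the assembled pattern is expected to behave; the sample log line is illustrative, using the non-Windows prefix:

```python
import re

# Illustrative only: mirrors the pattern built in patterns1 above.
prefix = "Cache miss due to new autograd node: "
pattern = (
    r".*"
    + prefix
    + r"torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), previous key sizes=\[\]\n"
)
sample = prefix + "torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]\n"
m = re.match(pattern, sample)
assert m is not None and m.group(1) == "39"  # the key size is captured
```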
@@ -3420,7 +3434,8 @@ def test_verbose_logs_dynamic_shapes(self): |
3420 | 3434 |
3421 | 3435 |         actual_logs = logs.getvalue() |
3422 | 3436 |         expected_logs = [ |
3423 | | -            "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]", |
| 3437 | +            self.gen_cache_miss_log_prefix() |
| 3438 | +            + "torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]", |
3424 | 3439 |         ] |
3425 | 3440 |         for expected in expected_logs: |
3426 | 3441 |             self.assertTrue(expected in actual_logs) |
@@ -3451,7 +3466,7 @@ def fn(): |
3451 | 3466 |                 fn() |
3452 | 3467 |
3453 | 3468 |         unexpected_logs = [ |
3454 | | -            "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0)" |
| 3469 | +            self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot (NodeCall 0)" |
3455 | 3470 |         ] |
3456 | 3471 |
3457 | 3472 |         self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0) |