Commit 99aa09a
Fix CUDA flaky tests for stochastic gates by using CPU-seeded RNG (#1802)
Summary:
Pull Request resolved: #1802
## Problem
CUDA RNG produces different random sequences on different GPU architectures (e.g. V100 vs A100 vs H100) even with the same seed set via `torch.manual_seed()`. This causes stochastic gate CUDA tests to be flaky in CI — the same test passes on one GPU type but fails on another because expected values were hardcoded for a specific architecture's RNG output.
Additionally, `test_p_norm_decay` uses exact `assert ==` for floating-point tensor comparison, which fails on GPU due to floating-point precision differences.
## Solution
**CPU-seeded RNG approach**: In CUDA test subclasses, patch `_sample_gate_values` to generate random noise on CPU (where `torch.manual_seed` is deterministic across all hardware) and then move the tensor to the GPU device. This keeps the full training codepath exercised (noise + mu → clamp → gather → multiply) while ensuring cross-architecture determinism.
For `LazyGaussianStochasticGates`, both `initialize_parameters` (mu initialization) and `_sample_gate_values` (noise sampling) happen on-device after `.to(cuda)`, so both are patched to use CPU RNG.
Since CUDA tests now produce identical values to CPU tests, the `if cpu / elif cuda` branches in base test files are removed, along with associated `pyre-fixme[61]` comments.
For `test_p_norm_decay`, exact `assert ==` is replaced with `assertTensorAlmostEqual` with `delta=0.01` tolerance.
## Files Changed
- `test_gaussian_stochastic_gates_cuda.py`: Patch `_sample_gate_values` with CPU-seeded `normal_()` sampling
- `test_kuma_stochastic_gates_cuda.py`: Patch `_sample_gate_values` with CPU-seeded `uniform_()` sampling + Kumaraswamy transform
- `test_lazy_gaussian_stochastic_gates_cuda.py`: Patch both `initialize_parameters` and `_sample_gate_values`
- `test_gaussian_stochastic_gates.py`: Remove cpu/cuda branches (4 tests)
- `test_kuma_stochastic_gates.py`: Remove cpu/cuda branches (4 tests)
- `test_lazy_gaussian_stochastic_gates.py`: Remove cpu/cuda branches (12 tests)
- `test_p_norm_decay.py`: Use `assertTensorAlmostEqual` instead of exact equality (2 tests)
Reviewed By: craymichael
Differential Revision: D97775614
fbshipit-source-id: 348ad6f317838fd5577fcc19a18bed256905c5b71 parent 458e134 commit 99aa09a
File tree
4 files changed
+96
-70
lines changed- tests/module
4 files changed
+96
-70
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
| 35 | + | |
39 | 36 | | |
40 | | - | |
41 | 37 | | |
42 | 38 | | |
43 | 39 | | |
| |||
110 | 106 | | |
111 | 107 | | |
112 | 108 | | |
113 | | - | |
114 | | - | |
115 | | - | |
116 | | - | |
| 109 | + | |
117 | 110 | | |
118 | | - | |
119 | 111 | | |
120 | 112 | | |
121 | 113 | | |
| |||
143 | 135 | | |
144 | 136 | | |
145 | 137 | | |
146 | | - | |
147 | | - | |
148 | | - | |
149 | | - | |
150 | | - | |
151 | | - | |
152 | | - | |
153 | | - | |
154 | | - | |
155 | | - | |
156 | | - | |
157 | | - | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
158 | 142 | | |
159 | 143 | | |
160 | 144 | | |
| |||
207 | 191 | | |
208 | 192 | | |
209 | 193 | | |
210 | | - | |
211 | | - | |
212 | | - | |
213 | | - | |
214 | | - | |
215 | | - | |
216 | | - | |
217 | | - | |
218 | | - | |
219 | | - | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
220 | 198 | | |
221 | | - | |
222 | 199 | | |
223 | 200 | | |
224 | 201 | | |
| |||
Lines changed: 42 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3 | 3 | | |
4 | 4 | | |
5 | 5 | | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
6 | 13 | | |
7 | 14 | | |
8 | 15 | | |
9 | | - | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
10 | 37 | | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
| 36 | + | |
36 | 37 | | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | 38 | | |
44 | 39 | | |
45 | 40 | | |
| |||
90 | 85 | | |
91 | 86 | | |
92 | 87 | | |
| 88 | + | |
93 | 89 | | |
94 | | - | |
95 | | - | |
96 | | - | |
97 | | - | |
98 | | - | |
99 | | - | |
100 | 90 | | |
101 | 91 | | |
102 | 92 | | |
| |||
137 | 127 | | |
138 | 128 | | |
139 | 129 | | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
140 | 134 | | |
141 | | - | |
142 | | - | |
143 | | - | |
144 | | - | |
145 | | - | |
146 | | - | |
147 | | - | |
148 | | - | |
149 | | - | |
150 | | - | |
151 | | - | |
152 | | - | |
153 | 135 | | |
154 | 136 | | |
155 | 137 | | |
| |||
200 | 182 | | |
201 | 183 | | |
202 | 184 | | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
203 | 189 | | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
210 | | - | |
211 | | - | |
212 | | - | |
213 | | - | |
214 | | - | |
215 | | - | |
216 | 190 | | |
217 | 191 | | |
218 | 192 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3 | 3 | | |
4 | 4 | | |
5 | 5 | | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
6 | 13 | | |
7 | 14 | | |
8 | 15 | | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
9 | 29 | | |
10 | 30 | | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
0 commit comments