Skip to content

Commit 99aa09a

Browse files
cyrjano authored and meta-codesync[bot] committed
Fix CUDA flaky tests for stochastic gates by using CPU-seeded RNG (#1802)
Summary: Pull Request resolved: #1802

## Problem

CUDA RNG produces different random sequences on different GPU architectures (e.g. V100 vs A100 vs H100) even with the same seed set via `torch.manual_seed()`. This causes stochastic gate CUDA tests to be flaky in CI — the same test passes on one GPU type but fails on another because expected values were hardcoded for a specific architecture's RNG output.

Additionally, `test_p_norm_decay` uses exact `assert ==` for floating-point tensor comparison, which fails on GPU due to floating-point precision differences.

## Solution

**CPU-seeded RNG approach**: In CUDA test subclasses, patch `_sample_gate_values` to generate random noise on CPU (where `torch.manual_seed` is deterministic across all hardware) and then move the tensor to the GPU device. This keeps the full training codepath exercised (noise + mu → clamp → gather → multiply) while ensuring cross-architecture determinism.

For `LazyGaussianStochasticGates`, both `initialize_parameters` (mu initialization) and `_sample_gate_values` (noise sampling) happen on-device after `.to(cuda)`, so both are patched to use CPU RNG.

Since CUDA tests now produce identical values to CPU tests, the `if cpu / elif cuda` branches in base test files are removed, along with associated `pyre-fixme[61]` comments.

For `test_p_norm_decay`, exact `assert ==` is replaced with `assertTensorAlmostEqual` with `delta=0.01` tolerance.

## Files Changed

- `test_gaussian_stochastic_gates_cuda.py`: Patch `_sample_gate_values` with CPU-seeded `normal_()` sampling
- `test_kuma_stochastic_gates_cuda.py`: Patch `_sample_gate_values` with CPU-seeded `uniform_()` sampling + Kumaraswamy transform
- `test_lazy_gaussian_stochastic_gates_cuda.py`: Patch both `initialize_parameters` and `_sample_gate_values`
- `test_gaussian_stochastic_gates.py`: Remove cpu/cuda branches (4 tests)
- `test_kuma_stochastic_gates.py`: Remove cpu/cuda branches (4 tests)
- `test_lazy_gaussian_stochastic_gates.py`: Remove cpu/cuda branches (12 tests)
- `test_p_norm_decay.py`: Use `assertTensorAlmostEqual` instead of exact equality (2 tests)

Reviewed By: craymichael

Differential Revision: D97775614

fbshipit-source-id: 348ad6f317838fd5577fcc19a18bed256905c5b7
1 parent 458e134 commit 99aa09a

File tree

4 files changed

+96
-70
lines changed

4 files changed

+96
-70
lines changed

tests/module/test_binary_concrete_stochastic_gates.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,8 @@ def test_bcstg_1d_input(self) -> None:
3232
gated_input, reg = bcstg(input_tensor)
3333
expected_reg = 2.4947
3434

35-
if self.testing_device == "cpu":
36-
expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
37-
elif self.testing_device == "cuda":
38-
expected_gated_input = [[0.0000, 0.0985, 0.1149], [0.2329, 0.0497, 0.5000]]
35+
expected_gated_input = [[0.0000, 0.0212, 0.1892], [0.1839, 0.3753, 0.4937]]
3936

40-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
4137
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
4238
assertTensorAlmostEqual(self, reg, expected_reg)
4339

@@ -110,12 +106,8 @@ def test_bcstg_1d_input_with_mask(self) -> None:
110106
gated_input, reg = bcstg(input_tensor)
111107
expected_reg = 1.6643
112108

113-
if self.testing_device == "cpu":
114-
expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
115-
elif self.testing_device == "cuda":
116-
expected_gated_input = [[0.0000, 0.0000, 0.1971], [0.1737, 0.2317, 0.3888]]
109+
expected_gated_input = [[0.0000, 0.0000, 0.1679], [0.0000, 0.0000, 0.2223]]
117110

118-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
119111
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
120112
assertTensorAlmostEqual(self, reg, expected_reg)
121113

@@ -143,18 +135,10 @@ def test_bcstg_2d_input(self) -> None:
143135
gated_input, reg = bcstg(input_tensor)
144136

145137
expected_reg = 4.9903
146-
expected_gated_input = []
147-
148-
if self.testing_device == "cpu":
149-
expected_gated_input = [
150-
[[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
151-
[[0.0476, 0.6177], [0.5400, 0.1530], [0.0984, 0.8013]],
152-
]
153-
elif self.testing_device == "cuda":
154-
expected_gated_input = [
155-
[[0.0000, 0.0985], [0.1149, 0.2331], [0.0486, 0.5000]],
156-
[[0.1840, 0.1571], [0.4612, 0.7937], [0.2975, 0.7393]],
157-
]
138+
expected_gated_input = [
139+
[[0.0000, 0.0990], [0.0261, 0.2431], [0.0551, 0.3863]],
140+
[[0.0476, 0.6177], [0.5400, 0.1530], [0.0984, 0.8013]],
141+
]
158142

159143
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
160144
assertTensorAlmostEqual(self, reg, expected_reg)
@@ -207,18 +191,11 @@ def test_bcstg_2d_input_with_mask(self) -> None:
207191
gated_input, reg = bcstg(input_tensor)
208192
expected_reg = 2.4947
209193

210-
if self.testing_device == "cpu":
211-
expected_gated_input = [
212-
[[0.0000, 0.0212], [0.0424, 0.0636], [0.3191, 0.4730]],
213-
[[0.3678, 0.6568], [0.7507, 0.8445], [0.6130, 1.0861]],
214-
]
215-
elif self.testing_device == "cuda":
216-
expected_gated_input = [
217-
[[0.0000, 0.0985], [0.1971, 0.2956], [0.0000, 0.2872]],
218-
[[0.4658, 0.0870], [0.0994, 0.1119], [0.7764, 1.1000]],
219-
]
194+
expected_gated_input = [
195+
[[0.0000, 0.0212], [0.0424, 0.0636], [0.3191, 0.4730]],
196+
[[0.3678, 0.6568], [0.7507, 0.8445], [0.6130, 1.0861]],
197+
]
220198

221-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
222199
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
223200
assertTensorAlmostEqual(self, reg, expected_reg)
224201

tests/module/test_binary_concrete_stochastic_gates_cuda.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,49 @@
33

44
# pyre-strict
55

6+
import unittest
7+
from unittest.mock import patch
8+
9+
import torch
10+
from captum.module.binary_concrete_stochastic_gates import BinaryConcreteStochasticGates
11+
from torch import Tensor
12+
613
from .test_binary_concrete_stochastic_gates import TestBinaryConcreteStochasticGates
714

815

9-
class TestBinaryConcreteStochasticGatesCUDA(TestBinaryConcreteStochasticGates):
16+
# CUDA RNG produces different sequences on different GPU architectures
17+
# (e.g. V100 vs A100 vs H100) even with the same seed, causing flaky
18+
# tests. By generating uniform samples on CPU and moving to the device,
19+
# tests get consistent results regardless of which GPU type runs them.
20+
def _cpu_rng_sample(self: BinaryConcreteStochasticGates, batch_size: int) -> Tensor:
    """Return gate values of shape (batch_size, n_gates) using CPU-drawn noise.

    Meant to be patched in for ``_sample_gate_values`` in CUDA tests: the
    uniform noise is drawn on CPU and only then moved to the device holding
    ``log_alpha_param``, so the random draw itself never uses the CUDA RNG.

    Args:
        batch_size: number of rows of gate values to produce.

    Returns:
        The sigmoid gate values rescaled from (0, 1) onto
        (``lower_bound``, ``upper_bound``).
    """
    log_alpha = self.log_alpha_param

    if not self.training:
        # Eval mode draws no noise; broadcast the per-gate value over the batch.
        gate = torch.sigmoid(log_alpha).expand(batch_size, self.n_gates)
    else:
        # torch.empty without a device argument allocates on the default
        # device, which is what keeps the uniform_ draw off the GPU.
        noise = torch.empty(batch_size, self.n_gates).uniform_(self.eps, 1 - self.eps)
        noise = noise.to(log_alpha.device)
        gate = torch.sigmoid((torch.logit(noise) + log_alpha) / self.temperature)

    span = self.upper_bound - self.lower_bound
    return gate * span + self.lower_bound
32+
33+
34+
class TestBinaryConcreteStochasticGatesCUDA(TestBinaryConcreteStochasticGates):
    """Run the tests inherited from the base class with the device set to CUDA."""

    testing_device: str = "cuda"

    def setUp(self) -> None:
        """Skip without CUDA; otherwise swap in CPU-seeded gate sampling.

        The patch replaces ``_sample_gate_values`` on the class itself and is
        undone by a registered cleanup once the test finishes.
        """
        super().setUp()

        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            raise unittest.SkipTest("Skipping GPU test since CUDA not available.")

        # pyre-fixme[8]: Attribute has type
        # `BoundMethod[..., Tensor]`; used as `(...) -> Tensor`.
        rng_patch = patch.object(
            BinaryConcreteStochasticGates,
            "_sample_gate_values",
            new=_cpu_rng_sample,
        )
        rng_patch.start()
        # Register the undo only after start() has succeeded.
        self.addCleanup(rng_patch.stop)

tests/module/test_gaussian_stochastic_gates.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,8 @@ def test_gstg_1d_input(self) -> None:
3333

3434
gated_input, reg = gstg(input_tensor)
3535
expected_reg = 2.5213
36+
expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
3637

37-
if self.testing_device == "cpu":
38-
expected_gated_input = [[0.0000, 0.0198, 0.1483], [0.1848, 0.3402, 0.1782]]
39-
elif self.testing_device == "cuda":
40-
expected_gated_input = [[0.0000, 0.0788, 0.0470], [0.0134, 0.0000, 0.1884]]
41-
42-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
4338
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
4439
assertTensorAlmostEqual(self, reg, expected_reg)
4540

@@ -90,13 +85,8 @@ def test_gstg_1d_input_with_mask(self) -> None:
9085

9186
gated_input, reg = gstg(input_tensor)
9287
expected_reg = 1.6849
88+
expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
9389

94-
if self.testing_device == "cpu":
95-
expected_gated_input = [[0.0000, 0.0000, 0.1225], [0.0583, 0.0777, 0.3779]]
96-
elif self.testing_device == "cuda":
97-
expected_gated_input = [[0.0000, 0.0000, 0.1577], [0.0736, 0.0981, 0.0242]]
98-
99-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
10090
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
10191
assertTensorAlmostEqual(self, reg, expected_reg)
10292

@@ -137,19 +127,11 @@ def test_gstg_2d_input(self) -> None:
137127

138128
gated_input, reg = gstg(input_tensor)
139129
expected_reg = 5.0458
130+
expected_gated_input = [
131+
[[0.0000, 0.0851], [0.0713, 0.3000], [0.2180, 0.1878]],
132+
[[0.2538, 0.0000], [0.3391, 0.8501], [0.3633, 0.8913]],
133+
]
140134

141-
if self.testing_device == "cpu":
142-
expected_gated_input = [
143-
[[0.0000, 0.0851], [0.0713, 0.3000], [0.2180, 0.1878]],
144-
[[0.2538, 0.0000], [0.3391, 0.8501], [0.3633, 0.8913]],
145-
]
146-
elif self.testing_device == "cuda":
147-
expected_gated_input = [
148-
[[0.0000, 0.0788], [0.0470, 0.0139], [0.0000, 0.1960]],
149-
[[0.0000, 0.7000], [0.1052, 0.2120], [0.5978, 0.0166]],
150-
]
151-
152-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
153135
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
154136
assertTensorAlmostEqual(self, reg, expected_reg)
155137

@@ -200,19 +182,11 @@ def test_gstg_2d_input_with_mask(self) -> None:
200182

201183
gated_input, reg = gstg(input_tensor)
202184
expected_reg = 2.5213
185+
expected_gated_input = [
186+
[[0.0000, 0.0198], [0.0396, 0.0594], [0.2435, 0.3708]],
187+
[[0.3696, 0.5954], [0.6805, 0.7655], [0.6159, 0.3921]],
188+
]
203189

204-
if self.testing_device == "cpu":
205-
expected_gated_input = [
206-
[[0.0000, 0.0198], [0.0396, 0.0594], [0.2435, 0.3708]],
207-
[[0.3696, 0.5954], [0.6805, 0.7655], [0.6159, 0.3921]],
208-
]
209-
elif self.testing_device == "cuda":
210-
expected_gated_input = [
211-
[[0.0000, 0.0788], [0.1577, 0.2365], [0.0000, 0.1174]],
212-
[[0.0269, 0.0000], [0.0000, 0.0000], [0.0448, 0.4145]],
213-
]
214-
215-
# pyre-fixme[61]: `expected_gated_input` is undefined, or not always defined.
216190
assertTensorAlmostEqual(self, gated_input, expected_gated_input, mode="max")
217191
assertTensorAlmostEqual(self, reg, expected_reg)
218192

tests/module/test_gaussian_stochastic_gates_cuda.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,42 @@
33

44
# pyre-strict
55

6+
import unittest
7+
from unittest.mock import patch
8+
9+
import torch
10+
from captum.module.gaussian_stochastic_gates import GaussianStochasticGates
11+
from torch import Tensor
12+
613
from .test_gaussian_stochastic_gates import TestGaussianStochasticGates
714

815

16+
# CUDA RNG produces different sequences on different GPU architectures
17+
# (e.g. V100 vs A100 vs H100) even with the same seed, causing flaky tests.
18+
# By generating noise on CPU (where torch.manual_seed is deterministic across
19+
# all hardware) and moving to the device, tests get consistent results
20+
# regardless of which GPU type runs them in CI.
21+
def _cpu_rng_sample(self: GaussianStochasticGates, batch_size: int) -> Tensor:
    """Return gate values of shape (batch_size, n_gates) using CPU-drawn noise.

    Meant to be patched in for ``_sample_gate_values`` in CUDA tests: the
    Gaussian noise is drawn on CPU and only then moved to the device holding
    ``mu``, so the random draw itself never uses the CUDA RNG.

    Args:
        batch_size: number of rows of gate values to produce.

    Returns:
        ``mu`` plus zero-mean noise with standard deviation ``std`` in
        training mode; ``mu`` expanded over the batch otherwise.
    """
    shape = (batch_size, self.n_gates)

    if not self.training:
        # Eval mode draws no noise; broadcast mu over the batch.
        return self.mu.expand(*shape)

    # torch.empty without a device argument allocates on the default device,
    # which is what keeps the normal_ draw off the GPU.
    noise = torch.empty(*shape).normal_(mean=0, std=self.std)
    return self.mu + noise.to(self.mu.device)
27+
28+
929
class TestGaussianStochasticGatesCUDA(TestGaussianStochasticGates):
    """Run the tests inherited from the base class with the device set to CUDA."""

    testing_device: str = "cuda"

    def setUp(self) -> None:
        """Skip without CUDA; otherwise swap in CPU-seeded gate sampling.

        The patch replaces ``_sample_gate_values`` on the class itself and is
        undone by a registered cleanup once the test finishes.
        """
        super().setUp()

        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            raise unittest.SkipTest("Skipping GPU test since CUDA not available.")

        # pyre-fixme[8]: Attribute has type
        # `BoundMethod[..., Tensor]`; used as `(...) -> Tensor`.
        rng_patch = patch.object(
            GaussianStochasticGates,
            "_sample_gate_values",
            new=_cpu_rng_sample,
        )
        rng_patch.start()
        # Register the undo only after start() has succeeded.
        self.addCleanup(rng_patch.stop)

0 commit comments

Comments
 (0)