Skip to content

Commit 3f82ce1

Browse files
authored
fix: coupling flows and CMs not working on GPU due to int type (#273)
Apparently, int32 variables are not transferred to the GPU, leading to problems with XLA. Changing the type declarations to int seems to fix the problem
1 parent 554ed13 commit 3f82ce1

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
self.s0 = float(s0)
8484
self.s1 = float(s1)
8585
# create variable that works with JIT compilation
86-
self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int32")
86+
self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int")
8787
self.current_step.assign(0)
8888

8989
self.seed_generator = keras.random.SeedGenerator()
@@ -258,7 +258,7 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
258258
self.current_step.assign(ops.minimum(self.current_step, self.total_steps - 1))
259259

260260
discretization_index = ops.take(
261-
self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int32")
261+
self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int")
262262
)
263263
discretized_time = ops.take(self.discretized_times, discretization_index, axis=0)
264264

bayesflow/networks/coupling_flow/permutations/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def build(self, xz_shape: Shape, **kwargs) -> None:
1616
shape=(xz_shape[-1],),
1717
initializer=keras.initializers.Constant(forward_indices),
1818
trainable=False,
19-
dtype="int32",
19+
dtype="int",
2020
)
2121

2222
self.inverse_indices = self.add_weight(
2323
shape=(xz_shape[-1],),
2424
initializer=keras.initializers.Constant(inverse_indices),
2525
trainable=False,
26-
dtype="int32",
26+
dtype="int",
2727
)

bayesflow/networks/coupling_flow/permutations/swap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def build(self, xz_shape: Shape, **kwargs) -> None:
1616
shape=(xz_shape[-1],),
1717
initializer=keras.initializers.Constant(forward_indices),
1818
trainable=False,
19-
dtype="int32",
19+
dtype="int",
2020
)
2121

2222
self.inverse_indices = self.add_variable(
2323
shape=(xz_shape[-1],),
2424
initializer=keras.initializers.Constant(inverse_indices),
2525
trainable=False,
26-
dtype="int32",
26+
dtype="int",
2727
)

0 commit comments

Comments
 (0)