
Commit 9ce6902

Update

[ghstack-poisoned]

2 parents cb207f2 + 2b60639

File tree: 5 files changed, +223 -22 lines

test/test_distributions.py

Lines changed: 180 additions & 0 deletions

@@ -6,6 +6,7 @@
 
 import argparse
 import importlib.util
+from functools import partial
 
 import pytest
 import torch
@@ -16,6 +17,7 @@
 from torch import autograd, nn
 from torch.utils._pytree import tree_map
 from torchrl.modules import (
+    IndependentNormal,
     OneHotCategorical,
     OneHotOrdinal,
     Ordinal,
@@ -169,6 +171,184 @@ def test_tanhnormal_event_dims(self, event_dims):
             exp_shape,
         )
 
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize(
+        "callable_scale",
+        [torch.ones_like, partial(torch.full_like, fill_value=0.5)],
+        ids=["ones_like", "full_like_partial"],
+    )
+    def test_tanhnormal_callable_scale(self, device, callable_scale):
+        """Test that TanhNormal supports callable scale for compile-friendliness.
+
+        Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
+        avoids explicit device transfers and prevents graph breaks in torch.compile.
+        """
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        # Create distribution with callable scale
+        dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)
+
+        # Check that the scale was properly resolved
+        expected_scale = callable_scale(loc)
+        torch.testing.assert_close(dist.scale, expected_scale)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+        assert (sample >= -1).all()
+        assert (sample <= 1).all()
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
+        # Test rsample with gradient
+        loc_grad = torch.randn(3, 4, device=device, requires_grad=True)
+        dist_grad = TanhNormal(loc=loc_grad, scale=callable_scale, low=-1, high=1)
+        sample_grad = dist_grad.rsample()
+        loss = sample_grad.sum()
+        loss.backward()
+        assert loc_grad.grad is not None
+        assert torch.isfinite(loc_grad.grad).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_tanhnormal_callable_scale_update(self, device):
+        """Test that TanhNormal.update() works with callable scale."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+        callable_scale = torch.ones_like
+
+        dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)
+
+        # Update with new loc and callable scale
+        new_loc = torch.randn(3, 4, device=device)
+        dist.update(new_loc, callable_scale)
+
+        # Check that scale was properly resolved
+        torch.testing.assert_close(dist.scale, torch.ones_like(new_loc))
+
+        # Verify distribution works after update
+        sample = dist.sample()
+        assert sample.shape == new_loc.shape
+        assert torch.isfinite(dist.log_prob(sample)).all()
+
+
+class TestIndependentNormal:
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize(
+        "callable_scale",
+        [torch.ones_like, partial(torch.full_like, fill_value=0.5)],
+        ids=["ones_like", "full_like_partial"],
+    )
+    def test_independentnormal_callable_scale(self, device, callable_scale):
+        """Test that IndependentNormal supports callable scale for compile-friendliness.
+
+        Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
+        avoids explicit device transfers and prevents graph breaks in torch.compile.
+        """
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        # Create distribution with callable scale
+        dist = IndependentNormal(loc=loc, scale=callable_scale)
+
+        # Check that the scale was properly resolved
+        expected_scale = callable_scale(loc)
+        torch.testing.assert_close(dist.base_dist.scale, expected_scale)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
+        # Test rsample with gradient
+        loc_grad = torch.randn(3, 4, device=device, requires_grad=True)
+        dist_grad = IndependentNormal(loc=loc_grad, scale=callable_scale)
+        sample_grad = dist_grad.rsample()
+        loss = sample_grad.sum()
+        loss.backward()
+        assert loc_grad.grad is not None
+        assert torch.isfinite(loc_grad.grad).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_independentnormal_callable_scale_update(self, device):
+        """Test that IndependentNormal.update() works with callable scale."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+        callable_scale = torch.ones_like
+
+        dist = IndependentNormal(loc=loc, scale=callable_scale)
+
+        # Update with new loc and callable scale
+        new_loc = torch.randn(3, 4, device=device)
+        dist.update(new_loc, callable_scale)
+
+        # Check that scale was properly resolved
+        torch.testing.assert_close(dist.base_dist.scale, torch.ones_like(new_loc))
+
+        # Verify distribution works after update
+        sample = dist.sample()
+        assert sample.shape == new_loc.shape
+        assert torch.isfinite(dist.log_prob(sample)).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"])
+    def test_independentnormal_scale_types(self, device, scale_type):
+        """Test that IndependentNormal supports all scale types: tensor, float, callable."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        if scale_type == "tensor":
+            scale = torch.ones(3, 4, device=device)
+        elif scale_type == "float":
+            scale = 1.0
+        else:  # callable
+            scale = torch.ones_like
+
+        dist = IndependentNormal(loc=loc, scale=scale)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"])
+    def test_tanhnormal_scale_types(self, device, scale_type):
+        """Test that TanhNormal supports all scale types: tensor, float, callable."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        if scale_type == "tensor":
+            scale = torch.ones(3, 4, device=device)
+        elif scale_type == "float":
+            scale = 1.0
+        else:  # callable
+            scale = torch.ones_like
+
+        dist = TanhNormal(loc=loc, scale=scale, low=-1, high=1)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+        assert (sample >= -1).all()
+        assert (sample <= 1).all()
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
 
 class TestTruncatedNormal:
     @pytest.mark.parametrize(
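The pattern these tests exercise is simple: when ``scale`` is a callable, the distribution resolves it against ``loc``, so the resulting scale tensor automatically matches ``loc``'s shape, dtype, and device. Below is a minimal standalone sketch of that resolution logic, for illustration only; the helper name ``_resolve_scale`` and its exact branching are assumptions, not the actual torchrl implementation.

from functools import partial

import torch


def _resolve_scale(scale, loc):
    # Hypothetical helper mirroring the behaviour the tests above check:
    # a callable is applied to `loc`, a float is broadcast, a tensor is used as-is.
    if callable(scale):
        return scale(loc)
    if isinstance(scale, (int, float)):
        return torch.full_like(loc, float(scale))
    return scale


loc = torch.randn(3, 4)
assert _resolve_scale(torch.ones_like, loc).equal(torch.ones_like(loc))
assert _resolve_scale(partial(torch.full_like, fill_value=0.5), loc).equal(
    torch.full_like(loc, 0.5)
)
assert _resolve_scale(1.0, loc).equal(torch.ones_like(loc))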

torchrl/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import os
 import warnings
 import weakref
 from warnings import warn

torchrl/data/replay_buffers/samplers.py

Lines changed: 2 additions & 18 deletions

@@ -1556,30 +1556,14 @@ def _get_index(
         out_of_traj = relative_starts < 0
         if out_of_traj.any():
             # a negative start means sampling fewer elements
-            # Convert seq_length to tensor to avoid torch.compile inductor C++ codegen
-            # bug with mixed scalar/tensor int64 in blendv operations (see PyTorch #xyz)
-            seq_length_t = torch.as_tensor(
-                seq_length,
-                dtype=relative_starts.dtype,
-                device=relative_starts.device,
-            )
             seq_length = torch.where(
-                ~out_of_traj, seq_length_t, seq_length_t + relative_starts
-            )
-            relative_starts = torch.where(
-                ~out_of_traj, relative_starts, torch.zeros_like(relative_starts)
+                ~out_of_traj, seq_length, seq_length + relative_starts
             )
+            relative_starts = torch.where(~out_of_traj, relative_starts, 0)
         if self.span[1]:
             out_of_traj = relative_starts + seq_length > lengths[traj_idx]
             if out_of_traj.any():
                 # a negative start means sampling fewer elements
-                # Convert seq_length to tensor if it's still a scalar
-                if not isinstance(seq_length, torch.Tensor):
-                    seq_length = torch.as_tensor(
-                        seq_length,
-                        dtype=relative_starts.dtype,
-                        device=relative_starts.device,
-                    )
                 seq_length = torch.minimum(
                     seq_length, lengths[traj_idx] - relative_starts
                 )
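This simplification leans on the fact that, in recent PyTorch releases, ``torch.where`` accepts a Python scalar alongside a tensor operand, so the explicit ``torch.as_tensor(...)`` conversions removed above are not needed. A small standalone illustration (not the torchrl code; the values are made up):

import torch

relative_starts = torch.tensor([-1, 0, -2])
out_of_traj = relative_starts < 0
seq_length = 4  # plain Python int, as in the simplified code path

# The scalar is broadcast against the tensor operand, no device transfer required
seq_length = torch.where(~out_of_traj, seq_length, seq_length + relative_starts)
relative_starts = torch.where(~out_of_traj, relative_starts, 0)
print(seq_length)       # tensor([3, 4, 2])
print(relative_starts)  # tensor([0, 0, 0])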

torchrl/modules/distributions/continuous.py

Lines changed: 38 additions & 2 deletions

@@ -58,7 +58,11 @@ class IndependentNormal(D.Independent):
 
     Args:
         loc (torch.Tensor): normal distribution location parameter
-        scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
+        scale (torch.Tensor, float, or callable): normal distribution sigma parameter (squared root of variance).
+            Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+            Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+            avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+            in :func:`torch.compile`.
         upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula:
 
         .. math::
@@ -69,6 +73,20 @@ class IndependentNormal(D.Independent):
         tanh_loc (bool, optional): if ``False``, the above formula is used for
            the location scaling, otherwise the raw value
            is kept. Default is ``False``;
+
+    Example:
+        >>> import torch
+        >>> from functools import partial
+        >>> from torchrl.modules.distributions import IndependentNormal
+        >>> loc = torch.zeros(3, 4)
+        >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+        >>> dist = IndependentNormal(loc, scale=torch.ones_like)
+        >>> # For a custom scale value, use partial to create a callable
+        >>> dist = IndependentNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+        >>> sample = dist.sample()
+        >>> sample.shape
+        torch.Size([3, 4])
+
     """
 
     num_params: int = 2
@@ -330,7 +348,11 @@ class TanhNormal(FasterTransformedDistribution):
 
     Args:
         loc (torch.Tensor): normal distribution location parameter
-        scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
+        scale (torch.Tensor, float, or callable): normal distribution sigma parameter (squared root of variance).
+            Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+            Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+            avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+            in :func:`torch.compile`.
         upscale (torch.Tensor or number): 'a' scaling factor in the formula:
 
         .. math::
@@ -345,6 +367,20 @@ class TanhNormal(FasterTransformedDistribution):
            value is kept. Default is ``False``;
        safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
            This will currently break with :func:`torch.compile`.
+
+    Example:
+        >>> import torch
+        >>> from functools import partial
+        >>> from torchrl.modules.distributions import TanhNormal
+        >>> loc = torch.zeros(3, 4)
+        >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+        >>> dist = TanhNormal(loc, scale=torch.ones_like)
+        >>> # For a custom scale value, use partial to create a callable
+        >>> dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+        >>> sample = dist.sample()
+        >>> sample.shape
+        torch.Size([3, 4])
+
     """
 
     arg_constraints = {
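A minimal usage sketch building on the docstrings above (not part of the diff): the callable scale is resolved against ``loc``, so the scale tensor lands on ``loc``'s device without any ``torch.tensor(..., device=...)`` call. Whether a surrounding module compiles without graph breaks still depends on the rest of its code; the ``torch.compile`` wrap below is shown under that assumption.

from functools import partial

import torch
from torchrl.modules.distributions import TanhNormal


def sample_action(loc):
    # partial(torch.full_like, fill_value=0.1) builds the scale on loc's device
    dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1), low=-1, high=1)
    return dist.rsample()


action = sample_action(torch.zeros(3, 4))
compiled_sample = torch.compile(sample_action)  # intended to avoid the graph break a device transfer would cause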

torchrl/modules/distributions/utils.py

Lines changed: 3 additions & 1 deletion

@@ -32,7 +32,9 @@ def _cast_transform_device(transform, device):
         for attribute in dir(transform):
             value = getattr(transform, attribute)
             if isinstance(value, torch.Tensor):
-                setattr(transform, attribute, value.to(device, non_blocking=_non_blocking))
+                setattr(
+                    transform, attribute, value.to(device, non_blocking=_non_blocking)
+                )
         return transform
     else:
         raise TypeError(
