@@ -11616,7 +11616,7 @@ def _make_transform_env(self, out_key, base_env):
11616
11616
transform = KLRewardTransform(actor, out_keys=out_key)
11617
11617
return Compose(
11618
11618
TensorDictPrimer(
11619
- sample_log_prob =Unbounded(shape=base_env.action_spec.shape[:-1]),
11619
+ action_log_prob =Unbounded(shape=base_env.action_spec.shape[:-1]),
11620
11620
shape=base_env.shape,
11621
11621
),
11622
11622
transform,
@@ -11640,7 +11640,7 @@ def test_transform_no_env(self, in_key, out_key):
11640
11640
{
11641
11641
"action": torch.randn(*batch, 7),
11642
11642
"observation": torch.randn(*batch, 7),
11643
- "sample_log_prob ": torch.randn(*batch),
11643
+ "action_log_prob ": torch.randn(*batch),
11644
11644
},
11645
11645
batch,
11646
11646
)
@@ -11658,7 +11658,7 @@ def test_transform_compose(self):
11658
11658
"action": torch.randn(*batch, 7),
11659
11659
"observation": torch.randn(*batch, 7),
11660
11660
"next": {t[0].in_keys[0]: torch.zeros(*batch, 1)},
11661
- "sample_log_prob ": torch.randn(*batch),
11661
+ "action_log_prob ": torch.randn(*batch),
11662
11662
},
11663
11663
batch,
11664
11664
)
@@ -11678,7 +11678,7 @@ def test_transform_env(self, out_key):
11678
11678
base_env = self.envclass()
11679
11679
torch.manual_seed(0)
11680
11680
actor = self._make_actor()
11681
- # we need to patch the env and create a sample_log_prob spec to make check_env_specs happy
11681
+ # we need to patch the env and create a action_log_prob spec to make check_env_specs happy
11682
11682
env = TransformedEnv(
11683
11683
base_env,
11684
11684
Compose(
@@ -11711,7 +11711,7 @@ def update(x):
11711
11711
@pytest.mark.parametrize("out_key", [None, "some_stuff", ["some_stuff"]])
11712
11712
def test_single_trans_env_check(self, out_key):
11713
11713
base_env = self.envclass()
11714
- # we need to patch the env and create a sample_log_prob spec to make check_env_specs happy
11714
+ # we need to patch the env and create a action_log_prob spec to make check_env_specs happy
11715
11715
env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env))
11716
11716
check_env_specs(env)
11717
11717
@@ -11776,7 +11776,7 @@ def test_transform_model(self):
11776
11776
"action": torch.randn(*batch, 7),
11777
11777
"observation": torch.randn(*batch, 7),
11778
11778
"next": {t.in_keys[0]: torch.zeros(*batch, 1)},
11779
- "sample_log_prob ": torch.randn(*batch),
11779
+ "action_log_prob ": torch.randn(*batch),
11780
11780
},
11781
11781
batch,
11782
11782
)
@@ -11796,7 +11796,7 @@ def test_transform_rb(self, rbclass):
11796
11796
"action": torch.randn(*batch, 7),
11797
11797
"observation": torch.randn(*batch, 7),
11798
11798
"next": {t.in_keys[0]: torch.zeros(*batch, 1)},
11799
- "sample_log_prob ": torch.randn(*batch),
11799
+ "action_log_prob ": torch.randn(*batch),
11800
11800
},
11801
11801
batch,
11802
11802
)
0 commit comments