Commit d93dd9e

Stop gradient propagation during network update for TeacherStudent (#82)
As in #66
1 parent a5e2e42 commit d93dd9e

1 file changed: +2 -1 lines changed

rsl_rl/modules/student_teacher.py

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@ def entropy(self):
 
     def update_distribution(self, observations):
         mean = self.student(observations)
-        self.distribution = Normal(mean, mean * 0.0 + self.std)
+        std = self.std.expand_as(mean)
+        self.distribution = Normal(mean, std)
 
     def act(self, observations):
         self.update_distribution(observations)
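
The effect of this change can be checked in isolation. The following is a minimal sketch, not code from the repository: `std_param` and `mean` are hypothetical stand-ins for `self.std` and `self.student(observations)`, and the shapes are chosen only for illustration. It compares the old and new ways of building the scale tensor and asks autograd whether each one is connected to `mean`.

```python
import torch
from torch.distributions import Normal

# Hypothetical stand-ins (assumptions, not the module's real attributes):
std_param = torch.nn.Parameter(torch.ones(3))   # plays the role of self.std
mean = torch.randn(4, 3, requires_grad=True)    # plays the role of self.student(observations)

# Old form: the scale is computed from an expression involving `mean`,
# so it is attached to the mean's autograd graph.
std_old = mean * 0.0 + std_param

# New form: the scale is a broadcast view of the parameter alone.
std_new = std_param.expand_as(mean)

# Both forms give the same values and shape for finite `mean`.
same_values = torch.equal(std_old, std_new)

# Gradient of each scale tensor with respect to `mean`.
# allow_unused=True makes autograd return None when `mean` is not in the graph.
(grad_old,) = torch.autograd.grad(std_old.sum(), mean, allow_unused=True)
(grad_new,) = torch.autograd.grad(std_new.sum(), mean, allow_unused=True)

# The distribution is constructed the same way as in the patched method.
dist = Normal(mean, std_new)

print("same values:", same_values)
print("old form linked to mean:", grad_old is not None)
print("new form linked to mean:", grad_new is not None)
```

Under these assumptions, the old expression yields a gradient tensor of zeros with respect to `mean`, meaning the scale is still part of the student's graph, while `expand_as` yields no gradient path at all. The old form also turns a non-finite `mean` into a `nan` scale, since `inf * 0.0` is `nan`; the new form does not depend on the values of `mean`.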
