Skip to content

Commit 6dfb7ac

Browse files
Commit 6dfb7ac — "Black reformatting" (merge commit, 2 parents: 2486c58 + 202629d)

File tree

3 files changed: +8 lines added, −8 lines removed

ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ def compute_gradient_magnitude(
 226 226          if self._settings.use_actions:
 227 227              policy_action = self.get_action_input(policy_batch)
 228 228              expert_action = self.get_action_input(expert_batch)
 229     -            action_epsilon = torch.rand(policy_action.shape, device=policy_action.device)
     229 +            action_epsilon = torch.rand(
     230 +                policy_action.shape, device=policy_action.device
     231 +            )
 230 232          policy_dones = torch.as_tensor(
 231 233              policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
 232 234          ).unsqueeze(1)

ml-agents/mlagents/trainers/torch_entities/networks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ def total_goal_enc_size(self) -> int:
  86  86      def update_normalization(self, buffer: AgentBuffer) -> None:
  87  87          obs = ObsUtil.from_buffer(buffer, len(self.processors))
  88  88          for vec_input, enc in zip(obs, self.processors):
  89     -            if isinstance(enc, VectorInput):
  90     -                enc.update_normalization(
  91     -                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
  92     -                )
      89 +            if isinstance(enc, VectorInput):
      90 +                enc.update_normalization(
      91 +                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
      92 +                )
  93  93
  94  94      def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
  95  95          if self.normalize:

(NOTE: the removed and added lines here are textually identical; this hunk appears to be a whitespace-only re-indent by Black — exact original indentation is not recoverable from the scrape.)

ml-agents/mlagents/trainers/torch_entities/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,7 @@ def list_to_tensor(
 234 234          calling as_tensor on the list directly.
 235 235          """
 236 236          device = default_device()
 237     -        return torch.as_tensor(
 238     -            np.asanyarray(ndarray_list), dtype=dtype, device=device
 239     -        )
     237 +        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype, device=device)
 240 238
 241 239      @staticmethod
 242 240      def list_to_tensor_list(

0 commit comments

Comments
 (0)