Skip to content

Commit 12ccbd1

Browse files
committed
fix: initialisation weirdness
1 parent 1f02571 commit 12ccbd1

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

pufferlib/models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,14 @@ def __init__(self, env, policy, input_size=128, hidden_size=128):
117117
if "bias" in name:
118118
nn.init.constant_(param, 0)
119119
elif "weight" in name and param.ndim >= 2:
120-
nn.init.orthogonal_(param, 1.0)
120+
if param.device.type == 'mps':
121+
# Apple MPS does not support orthogonal
122+
123+
param.to(device='cpu')
124+
nn.init.orthogonal_(param, 1.0)
125+
param.to(device=param.device)
126+
else:
127+
nn.init.orthogonal_(param, 1.0)
121128

122129
self.lstm = nn.LSTM(input_size, hidden_size)
123130

pufferlib/pytorch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,13 @@ def _flattened_tensor_size(native_dtype):
164164

165165
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """CleanRL's default layer initialization.

    Fills ``layer.weight`` with an orthogonal matrix scaled by ``std`` and
    sets ``layer.bias`` (when the layer has one) to ``bias_const``.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a ``bias``.
        std: Gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: Constant written into every bias element.

    Returns:
        The same ``layer`` object, initialized in place.
    """
    weight = layer.weight
    if weight.device.type == 'mps':
        # Orthogonal init is run on the CPU for MPS-resident weights.
        # Tensor.to() returns a new tensor and never moves the original, so
        # build the orthogonal matrix in a CPU scratch tensor and copy the
        # values back into the real parameter.
        cpu_weight = torch.empty_like(weight, device='cpu')
        torch.nn.init.orthogonal_(cpu_weight, std)
        with torch.no_grad():
            weight.copy_(cpu_weight)
    else:
        torch.nn.init.orthogonal_(weight, std)

    # Layers constructed with bias=False expose ``bias`` as None.
    if getattr(layer, 'bias', None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
170176

0 commit comments

Comments
 (0)