Skip to content

Commit 30f648f

Browse files
committed
feature: support sparse optimizer
1 parent 0f8b1e6 commit 30f648f

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pytorch_optimizer/madgrad.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,24 +127,24 @@ def step(self, closure: CLOSURE = None) -> LOSS:
127127
s_masked = s.sparse_mask(grad)
128128

129129
# Compute x_0 from other known quantities
130-
rms_masked_values = grad_sum_sq_masked.data.pow(1 / 3).add_(eps)
131-
x0_masked_values = p_masked.data.addcdiv(s_masked.data, rms_masked_values, value=1)
130+
rms_masked_values = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
131+
x0_masked_values = p_masked._values().addcdiv(s_masked._values(), rms_masked_values, value=1)
132132

133133
# Dense + sparse op
134134
grad_sq = grad * grad
135135
grad_sum_sq.add_(grad_sq, alpha=_lambda)
136136
grad_sum_sq_masked.add_(grad_sq, alpha=_lambda)
137137

138-
rms_masked_values = grad_sum_sq_masked.data.pow_(1 / 3).add_(eps)
138+
rms_masked_values = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
139139

140140
s.add_(grad, alpha=_lambda)
141-
s_masked.data.add_(grad.data, alpha=_lambda)
141+
s_masked._values().add_(grad._values(), alpha=_lambda)
142142

143143
# update masked copy of p
144-
p_kp1_masked_values = x0_masked_values.addcdiv(s_masked.data, rms_masked_values, value=-1)
144+
p_kp1_masked_values = x0_masked_values.addcdiv(s_masked._values(), rms_masked_values, value=-1)
145145

146146
# Copy updated masked p to dense p using an add operation
147-
p_masked.data.add_(p_kp1_masked_values, alpha=-1)
147+
p_masked._values().add_(p_kp1_masked_values, alpha=-1)
148148
p.data.add_(p_masked, alpha=-1)
149149
else:
150150
if momentum == 0:

0 commit comments

Comments
 (0)