@@ -96,6 +96,7 @@ def backward(ctx, grad_output):
         use_zbv = ctx.use_zbv
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if bias is not None:
             bias = bias.view(bias.shape)
@@ -130,7 +131,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -141,7 +142,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -164,7 +165,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
@@ -212,6 +213,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
             return wgrad_gemm_func(_input_.t(), _grad_output_)
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if bias is not None:
             bias = bias.view(bias.shape)
@@ -232,7 +234,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -243,7 +245,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -266,7 +268,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
@@ -1026,6 +1028,7 @@ def backward(ctx, grad_output):
         use_zbv = ctx.use_zbv
 
         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        weight_origin = weight
         weight = weight.view(weight.shape)
         if use_bias:
             bias = bias.view(bias.shape)
@@ -1064,7 +1067,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
@@ -1075,7 +1078,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass_grad_accum,
                         wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
@@ -1098,7 +1101,7 @@ def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_f
                 WeightGradStore.put(
                     total_input,
                     grad_output,
-                    weight,
+                    (weight, weight_origin),
                     functools.partial(
                         execute_w_pass,
                         wgrad_gemm_func=torch.matmul,
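
A minimal sketch of the pattern this change relies on (illustrative only: HookedTensor stands in for a Gemini-managed tensor and is not a ColossalAI class). Calling weight.view(weight.shape) routes the op through '__torch_function__' so Gemini can intercept it, but it also returns a new tensor object; the pre-view reference is therefore kept as weight_origin and handed to WeightGradStore.put alongside the hooked view.

import torch

class HookedTensor(torch.Tensor):
    # Illustrative subclass: a Gemini-managed tensor intercepts ops the same
    # way via '__torch_function__' (e.g. to materialize the real parameter).
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("intercepted:", getattr(func, "__name__", func))
        return super().__torch_function__(func, types, args, kwargs)

weight = torch.randn(4, 4).as_subclass(HookedTensor)
weight_origin = weight              # keep a handle on the original tensor object
weight = weight.view(weight.shape)  # goes through __torch_function__, returns a NEW tensor
assert weight is not weight_origin  # only weight_origin still identifies the original object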