File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
colossalai/zero/low_level/bookkeeping Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -78,13 +78,13 @@ def build_grad_in_bucket(self):
78
78
}
79
79
"""
80
80
for param , padding_size in zip (self ._param_list , self ._padding_size ):
81
- grad = param .grad .clone (). detach ().flatten ()
81
+ grad = param .grad .detach ().flatten ()
82
82
if padding_size > 0 :
83
83
with torch .no_grad ():
84
84
grad = torch .nn .functional .pad (grad .view (- 1 ), [0 , padding_size ])
85
85
grad_list = grad .split (grad .numel () // self ._world_size )
86
86
for rank in range (self ._world_size ):
87
- grad_current_rank = grad_list [rank ].clone (). detach ()
87
+ grad_current_rank = grad_list [rank ].detach ()
88
88
self .grad_to_param_mapping [id (grad_current_rank )] = id (param )
89
89
self ._grad_in_bucket [rank ].append (grad_current_rank )
90
90
param .grad = None
@@ -110,7 +110,7 @@ def get_flatten_grad(self) -> Tensor:
110
110
111
111
flat_grad = []
112
112
for grad_list in self ._grad_in_bucket .values ():
113
- flat_grad .append ( _flatten_dense_tensors ( grad_list ) )
113
+ flat_grad .extend ( grad_list )
114
114
flat_grad = _flatten_dense_tensors (flat_grad )
115
115
return flat_grad
116
116
You can’t perform that action at this time.
0 commit comments