
Commit e8ee704

fix sharding stage3 bug (#60085) (#60106)
1 parent 1c0ffeb commit e8ee704

1 file changed: +27, -1 lines

python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py

Lines changed: 27 additions & 1 deletion
@@ -31,6 +31,32 @@
 from .group_sharded_utils import GroupShardedClipGrad, Type, device_guard
 
 
+class OrderedSet:
+    def __init__(self, iterable=None):
+        self._data = OrderedDict.fromkeys(iterable or [])
+
+    def __contains__(self, item):
+        return item in self._data
+
+    def __iter__(self):
+        return iter(self._data)
+
+    def __len__(self):
+        return len(self._data)
+
+    def add(self, item):
+        self._data[item] = None
+
+    def discard(self, item):
+        self._data.pop(item, None)
+
+    def update(self, iterable):
+        self._data.update((item, None) for item in iterable)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({list(self._data)})"
+
+
 def _all_gather(tensor, buffer_size, group):
     """
     The main difference with paddle.distributed.all_gather:
@@ -148,7 +174,7 @@ def __init__(
             {}
         )  # {param.name: [(start0, end0),(start1, end1), ...]}
         self._trainable_params = {}  # {id(layer): [trainable_params]}
-        self._unslice_params = set()  # param's numel <= segment_size
+        self._unslice_params = OrderedSet()  # param's numel <= segment_size
         self._unslice_params2align = {}  # {param.name: param's align}
         self._grad_storages = {}  # {param.dtype: GradStorage}
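
For context (not stated in the commit itself): the apparent reason for replacing the built-in set() with OrderedSet is that a plain set gives no guarantee about iteration order, so different worker processes could walk _unslice_params in different orders, while an OrderedDict-backed container always iterates in insertion order. Below is a minimal, self-contained sketch of that property, using hypothetical parameter names; only the methods needed for the demonstration are reproduced.

from collections import OrderedDict

class OrderedSet:
    # Same mechanism as the class added in this commit: the members are the
    # keys of an OrderedDict, so iteration follows insertion order.
    def __init__(self, iterable=None):
        self._data = OrderedDict.fromkeys(iterable or [])

    def add(self, item):
        self._data[item] = None

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

# Hypothetical parameter names, for illustration only.
unslice_params = OrderedSet()
for name in ["embedding_0.w_0", "linear_0.w_0", "linear_0.b_0"]:
    unslice_params.add(name)

# Iteration order equals insertion order on every process:
assert list(unslice_params) == ["embedding_0.w_0", "linear_0.w_0", "linear_0.b_0"]

# A plain set() makes no such ordering promise (for example when the members
# are parameter objects hashed by identity), so two ranks could visit the
# unsliced parameters in different orders and issue mismatched collectives.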
