Skip to content

Commit e58582b

Browse files
authored
4th-batch-89-未对元素进行检查 (PaddlePaddle#75803)
1 parent bd295b5 commit e58582b

File tree

1 file changed

+11
-4
lines changed
  • python/paddle/distributed/auto_parallel

1 file changed

+11
-4
lines changed

python/paddle/distributed/auto_parallel/api.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,17 @@ def __init__(self, mesh, sharding_specs):
209209
), 'The dimension name in sharding_specs must be an instance of str.'
210210

211211
self._sharding_specs = sharding_specs
212-
dims_mapping = [
213-
mesh.dim_names.index(dim_name) if dim_name is not None else -1
214-
for dim_name in sharding_specs
215-
]
212+
dims_mapping = []
213+
for dim_name in sharding_specs:
214+
if dim_name is None:
215+
dims_mapping.append(-1)
216+
else:
217+
if dim_name not in mesh.dim_names:
218+
raise ValueError(
219+
f"Invalid sharding dimension '{dim_name}'. "
220+
f"Available dimensions in mesh are: {mesh.dim_names}."
221+
)
222+
dims_mapping.append(mesh.dim_names.index(dim_name))
216223

217224
# 2. init core.TensorDistAttr
218225
core.TensorDistAttr.__init__(self)

0 commit comments

Comments
 (0)