Skip to content

Commit 4b3f9e5

Browse files
authored
fix params with only 1 dim (#15828) (#15832)
* fix params with only 1 dim * test=develop
1 parent 2307baf commit 4b3f9e5

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

python/paddle/fluid/io.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,10 @@ def __load_persistable_vars(executor, dirname, need_load_vars):
766766
dtype=slice_var.dtype,
767767
persistable=True)
768768

769-
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
769+
dim1_flatten = 1
770+
if len(slice.shape) >= 2:
771+
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
772+
770773
start = int(offset / dim1_flatten)
771774
end = int(offset / dim1_flatten + slice.shape[0])
772775

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,11 @@ def _get_slice_var_info(self, slice_var):
10201020
skip_dim0 = 0
10211021
slice_vars = self.param_var_mapping[orig_var_name]
10221022

1023-
orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:])
1023+
orig_dim1_flatten = 1
1024+
1025+
if len(slice_vars[0].shape) >= 2:
1026+
orig_dim1_flatten = reduce(lambda x, y: x * y,
1027+
slice_vars[0].shape[1:])
10241028

10251029
for slice_var in slice_vars[:block_idx]:
10261030
skip_dim0 += slice_var.shape[0]

0 commit comments

Comments
 (0)