Skip to content

Commit 5ab5687

Browse files
committed
remove no necessary doc changes. test=develop
1 parent 6b854f3 commit 5ab5687

File tree

1 file changed

+198
-2
lines changed

1 file changed

+198
-2
lines changed

python/paddle/fluid/framework.py

Lines changed: 198 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,183 @@ def _set_error_clip(self, error_clip):
627627
"""
628628
self.error_clip = error_clip
629629

630+
def _slice_indices(self, slice, length):
631+
"""
632+
Reference implementation for the slice.indices method.
633+
"""
634+
# Compute step and length as integers.
635+
step = 1 if slice.step is None else slice.step
636+
637+
# Raise ValueError for negative length or zero step.
638+
if length < 0:
639+
raise ValueError("length should not be negative")
640+
if step == 0:
641+
raise ValueError("slice step cannot be zero")
642+
643+
# Find lower and upper bounds for start and stop.
644+
lower = -1 if step < 0 else 0
645+
upper = length - 1 if step < 0 else length
646+
647+
# Compute start.
648+
if slice.start is None:
649+
start = upper if step < 0 else lower
650+
else:
651+
start = slice.start
652+
start = max(start + length, lower) if start < 0 else min(start,
653+
upper)
654+
655+
# Compute stop.
656+
if slice.stop is None:
657+
stop = lower if step < 0 else upper
658+
else:
659+
stop = slice.stop
660+
stop = max(stop + length, lower) if stop < 0 else min(stop, upper)
661+
662+
return start, stop, step
663+
664+
def _detectEllipsis(self, item):
665+
has_ellipsis = False
666+
start = 0
667+
end = len(self.shape)
668+
for index, o in enumerate(item):
669+
if o is Ellipsis:
670+
if has_ellipsis:
671+
raise ValueError("Index can have one ellipsis only.")
672+
has_ellipsis = True
673+
start = index
674+
else:
675+
if has_ellipsis:
676+
end = index
677+
return has_ellipsis, start, end
678+
679+
def _reconstructSliceinfo(self, item):
680+
has_ellipsis, start, end = self._detectEllipsis(item)
681+
if has_ellipsis:
682+
newitem = []
683+
for i in range(start):
684+
newitem.append(item[i])
685+
for i in range(start, end):
686+
newitem.append(slice(None, None, None))
687+
for i in range(end, len(item)):
688+
newitem.append(item[i])
689+
return newitem
690+
else:
691+
return None
692+
693+
def _detectContinuesSlice(self, item):
694+
starts = []
695+
ends = []
696+
for index, o in enumerate(item):
697+
if isinstance(o, int):
698+
start = int(o)
699+
if (index > 0 and index >= self.shape[index]) \
700+
or (index < 0 and (index + self.shape[index]) < 0):
701+
raise IndexError("invalid index")
702+
start = max(start + self.shape[index], 0) if start < 0 else min(
703+
start, self.shape[index])
704+
starts.append(start)
705+
ends.append(start + 1)
706+
elif isinstance(o, slice):
707+
start, stop, step = self._slice_indices(o, self.shape[index])
708+
if step == 1 or step == -1:
709+
starts.append(start)
710+
ends.append(stop)
711+
else:
712+
return False, None
713+
else:
714+
raise IndexError("Valid index accept int or slice or ellipsis")
715+
return True, [starts, ends]
716+
717+
def _cloneVar(self, copy=False):
718+
if not copy:
719+
return self.block.create_var(
720+
name=unique_name.generate(".".join(self.name)),
721+
dtype=self.dtype,
722+
persistable=self.persistable,
723+
stop_gradient=self._stop_gradient, )
724+
else:
725+
return self
726+
727+
def _sliceVar(self, axes, starts, ends):
728+
new_var = self._cloneVar()
729+
self.block.append_op(
730+
type="slice",
731+
inputs={'Input': [self]},
732+
outputs={'Out': [new_var]},
733+
attrs={'axes': axes,
734+
'starts': starts,
735+
'ends': ends})
736+
return new_var
737+
738+
def _concatVar(self, inputs, axis):
739+
new_var = self._cloneVar()
740+
self.block.append_op(
741+
type="concat",
742+
inputs={'X': inputs},
743+
outputs={'Out': [new_var]},
744+
attrs={'axis': axis, })
745+
return new_var
746+
747+
def _sliceAndConcatVar(self, item, axis):
748+
if isinstance(item, slice):
749+
if self.shape[axis] < 0:
750+
return self._cloneVar(True)
751+
start, stop, step = self._slice_indices(item, self.shape[axis])
752+
if step == 1:
753+
return self._sliceVar([axis], [start], [stop])
754+
else:
755+
vars = []
756+
if step > 0:
757+
while start < stop:
758+
vars.append(
759+
self._sliceVar([axis], [start], [start + 1]))
760+
start += step
761+
else:
762+
while start > stop:
763+
vars.append(
764+
self._sliceVar([axis], [start], [start + 1]))
765+
start += step
766+
return self._concatVar(vars, axis)
767+
elif isinstance(item, int):
768+
if self.shape[axis] < 0:
769+
return self._cloneVar(True)
770+
index = int(item)
771+
if (index > 0 and index >= self.shape[axis])\
772+
or (index < 0 and (index + self.shape[axis]) < 0):
773+
raise IndexError("invalid index")
774+
return self._sliceVar([axis], [index], [index + 1])
775+
else:
776+
raise IndexError("Valid index accept int or slice or tuple")
777+
778+
def __getitem__(self, item):
779+
"""
780+
Slice the variable.
781+
782+
Args:
783+
item(int/slice/tuple) : the index.
784+
785+
Returns:
786+
Sliced variable
787+
"""
788+
new_var = None
789+
if isinstance(item, tuple):
790+
if len(item) > len(self.shape):
791+
raise IndexError("Too many indexes")
792+
newitem = self._reconstructSliceinfo(item) or item
793+
check, info = self._detectContinuesSlice(newitem)
794+
if check:
795+
starts = info[0]
796+
ends = info[1]
797+
axes = [i for i in range(len(starts))]
798+
return self._sliceVar(axes, starts, ends)
799+
else:
800+
new_var = self
801+
for index, o in enumerate(newitem):
802+
new_var = new_var._sliceAndConcatVar(o, index)
803+
else:
804+
new_var = self._sliceAndConcatVar(item, 0)
805+
return new_var
806+
630807

631808
def get_all_op_protos():
632809
"""
@@ -744,7 +921,7 @@ def __init__(self,
744921
if _in_imperative_mode():
745922
if type is None:
746923
raise ValueError(
747-
"`type` to initilized an Operator can not be None.")
924+
"`type` to initialized an Operator can not be None.")
748925
self.iop = core.OpBase(type)
749926

750927
# TODO(minqiyang): remove these lines after we take apart all
@@ -906,7 +1083,10 @@ def __str__(self):
9061083

9071084
@property
9081085
def type(self):
909-
return self.desc.type()
1086+
if _in_imperative_mode():
1087+
return self.iop.type
1088+
else:
1089+
return self.desc.type()
9101090

9111091
def input(self, name):
9121092
"""
@@ -1022,6 +1202,9 @@ def _set_attr(self, name, val):
10221202
"""
10231203
self._update_desc_attr(name, val)
10241204

1205+
def _remove_attr(self, name):
1206+
self.desc.remove_attr(name)
1207+
10251208
def _update_desc_attr(self, name, val):
10261209
"""
10271210
Update the value of desc's attribute by attribute's name.
@@ -2515,6 +2698,10 @@ def __init__(self):
25152698
self._trainers_endpoints = []
25162699
# the distributed lookup table names
25172700
self._distributed_lookup_table = None
2701+
2702+
# use Deep gradient comrepssion or not
2703+
self._enable_dgc = False
2704+
25182705
# @deprecated(the python memory optimize transpiler is deprecated)
25192706
# whether the program is optimized by memory_optimize_transpiler
25202707
self.__is_mem_optimized = False
@@ -2565,6 +2752,15 @@ def op_role_var(self):
25652752
def set_op_role_var(self, var_name):
25662753
self._op_role_var = [var_name]
25672754

2755+
@contextlib.contextmanager
2756+
def _backward_role_guard(self):
2757+
tmp_role = self._current_role
2758+
2759+
OpRole = core.op_proto_and_checker_maker.OpRole
2760+
self._current_role = OpRole.Backward
2761+
yield
2762+
self._current_role = tmp_role
2763+
25682764
@signature_safe_contextmanager
25692765
def _optimized_guard(self, param_and_grads):
25702766
"""

0 commit comments

Comments
 (0)