@@ -627,6 +627,183 @@ def _set_error_clip(self, error_clip):
627
627
"""
628
628
self .error_clip = error_clip
629
629
630
+ def _slice_indices (self , slice , length ):
631
+ """
632
+ Reference implementation for the slice.indices method.
633
+ """
634
+ # Compute step and length as integers.
635
+ step = 1 if slice .step is None else slice .step
636
+
637
+ # Raise ValueError for negative length or zero step.
638
+ if length < 0 :
639
+ raise ValueError ("length should not be negative" )
640
+ if step == 0 :
641
+ raise ValueError ("slice step cannot be zero" )
642
+
643
+ # Find lower and upper bounds for start and stop.
644
+ lower = - 1 if step < 0 else 0
645
+ upper = length - 1 if step < 0 else length
646
+
647
+ # Compute start.
648
+ if slice .start is None :
649
+ start = upper if step < 0 else lower
650
+ else :
651
+ start = slice .start
652
+ start = max (start + length , lower ) if start < 0 else min (start ,
653
+ upper )
654
+
655
+ # Compute stop.
656
+ if slice .stop is None :
657
+ stop = lower if step < 0 else upper
658
+ else :
659
+ stop = slice .stop
660
+ stop = max (stop + length , lower ) if stop < 0 else min (stop , upper )
661
+
662
+ return start , stop , step
663
+
664
+ def _detectEllipsis (self , item ):
665
+ has_ellipsis = False
666
+ start = 0
667
+ end = len (self .shape )
668
+ for index , o in enumerate (item ):
669
+ if o is Ellipsis :
670
+ if has_ellipsis :
671
+ raise ValueError ("Index can have one ellipsis only." )
672
+ has_ellipsis = True
673
+ start = index
674
+ else :
675
+ if has_ellipsis :
676
+ end = index
677
+ return has_ellipsis , start , end
678
+
679
+ def _reconstructSliceinfo (self , item ):
680
+ has_ellipsis , start , end = self ._detectEllipsis (item )
681
+ if has_ellipsis :
682
+ newitem = []
683
+ for i in range (start ):
684
+ newitem .append (item [i ])
685
+ for i in range (start , end ):
686
+ newitem .append (slice (None , None , None ))
687
+ for i in range (end , len (item )):
688
+ newitem .append (item [i ])
689
+ return newitem
690
+ else :
691
+ return None
692
+
693
+ def _detectContinuesSlice (self , item ):
694
+ starts = []
695
+ ends = []
696
+ for index , o in enumerate (item ):
697
+ if isinstance (o , int ):
698
+ start = int (o )
699
+ if (index > 0 and index >= self .shape [index ]) \
700
+ or (index < 0 and (index + self .shape [index ]) < 0 ):
701
+ raise IndexError ("invalid index" )
702
+ start = max (start + self .shape [index ], 0 ) if start < 0 else min (
703
+ start , self .shape [index ])
704
+ starts .append (start )
705
+ ends .append (start + 1 )
706
+ elif isinstance (o , slice ):
707
+ start , stop , step = self ._slice_indices (o , self .shape [index ])
708
+ if step == 1 or step == - 1 :
709
+ starts .append (start )
710
+ ends .append (stop )
711
+ else :
712
+ return False , None
713
+ else :
714
+ raise IndexError ("Valid index accept int or slice or ellipsis" )
715
+ return True , [starts , ends ]
716
+
717
+ def _cloneVar (self , copy = False ):
718
+ if not copy :
719
+ return self .block .create_var (
720
+ name = unique_name .generate ("." .join (self .name )),
721
+ dtype = self .dtype ,
722
+ persistable = self .persistable ,
723
+ stop_gradient = self ._stop_gradient , )
724
+ else :
725
+ return self
726
+
727
+ def _sliceVar (self , axes , starts , ends ):
728
+ new_var = self ._cloneVar ()
729
+ self .block .append_op (
730
+ type = "slice" ,
731
+ inputs = {'Input' : [self ]},
732
+ outputs = {'Out' : [new_var ]},
733
+ attrs = {'axes' : axes ,
734
+ 'starts' : starts ,
735
+ 'ends' : ends })
736
+ return new_var
737
+
738
+ def _concatVar (self , inputs , axis ):
739
+ new_var = self ._cloneVar ()
740
+ self .block .append_op (
741
+ type = "concat" ,
742
+ inputs = {'X' : inputs },
743
+ outputs = {'Out' : [new_var ]},
744
+ attrs = {'axis' : axis , })
745
+ return new_var
746
+
747
+ def _sliceAndConcatVar (self , item , axis ):
748
+ if isinstance (item , slice ):
749
+ if self .shape [axis ] < 0 :
750
+ return self ._cloneVar (True )
751
+ start , stop , step = self ._slice_indices (item , self .shape [axis ])
752
+ if step == 1 :
753
+ return self ._sliceVar ([axis ], [start ], [stop ])
754
+ else :
755
+ vars = []
756
+ if step > 0 :
757
+ while start < stop :
758
+ vars .append (
759
+ self ._sliceVar ([axis ], [start ], [start + 1 ]))
760
+ start += step
761
+ else :
762
+ while start > stop :
763
+ vars .append (
764
+ self ._sliceVar ([axis ], [start ], [start + 1 ]))
765
+ start += step
766
+ return self ._concatVar (vars , axis )
767
+ elif isinstance (item , int ):
768
+ if self .shape [axis ] < 0 :
769
+ return self ._cloneVar (True )
770
+ index = int (item )
771
+ if (index > 0 and index >= self .shape [axis ])\
772
+ or (index < 0 and (index + self .shape [axis ]) < 0 ):
773
+ raise IndexError ("invalid index" )
774
+ return self ._sliceVar ([axis ], [index ], [index + 1 ])
775
+ else :
776
+ raise IndexError ("Valid index accept int or slice or tuple" )
777
+
778
+ def __getitem__ (self , item ):
779
+ """
780
+ Slice the variable.
781
+
782
+ Args:
783
+ item(int/slice/tuple) : the index.
784
+
785
+ Returns:
786
+ Sliced variable
787
+ """
788
+ new_var = None
789
+ if isinstance (item , tuple ):
790
+ if len (item ) > len (self .shape ):
791
+ raise IndexError ("Too many indexes" )
792
+ newitem = self ._reconstructSliceinfo (item ) or item
793
+ check , info = self ._detectContinuesSlice (newitem )
794
+ if check :
795
+ starts = info [0 ]
796
+ ends = info [1 ]
797
+ axes = [i for i in range (len (starts ))]
798
+ return self ._sliceVar (axes , starts , ends )
799
+ else :
800
+ new_var = self
801
+ for index , o in enumerate (newitem ):
802
+ new_var = new_var ._sliceAndConcatVar (o , index )
803
+ else :
804
+ new_var = self ._sliceAndConcatVar (item , 0 )
805
+ return new_var
806
+
630
807
631
808
def get_all_op_protos ():
632
809
"""
@@ -744,7 +921,7 @@ def __init__(self,
744
921
if _in_imperative_mode ():
745
922
if type is None :
746
923
raise ValueError (
747
- "`type` to initilized an Operator can not be None." )
924
+ "`type` to initialized an Operator can not be None." )
748
925
self .iop = core .OpBase (type )
749
926
750
927
# TODO(minqiyang): remove these lines after we take apart all
@@ -906,7 +1083,10 @@ def __str__(self):
906
1083
907
1084
@property
908
1085
def type (self ):
909
- return self .desc .type ()
1086
+ if _in_imperative_mode ():
1087
+ return self .iop .type
1088
+ else :
1089
+ return self .desc .type ()
910
1090
911
1091
def input (self , name ):
912
1092
"""
@@ -1022,6 +1202,9 @@ def _set_attr(self, name, val):
1022
1202
"""
1023
1203
self ._update_desc_attr (name , val )
1024
1204
1205
+ def _remove_attr (self , name ):
1206
+ self .desc .remove_attr (name )
1207
+
1025
1208
def _update_desc_attr (self , name , val ):
1026
1209
"""
1027
1210
Update the value of desc's attribute by attribute's name.
@@ -2515,6 +2698,10 @@ def __init__(self):
2515
2698
self ._trainers_endpoints = []
2516
2699
# the distributed lookup table names
2517
2700
self ._distributed_lookup_table = None
2701
+
2702
+ # use Deep gradient comrepssion or not
2703
+ self ._enable_dgc = False
2704
+
2518
2705
# @deprecated(the python memory optimize transpiler is deprecated)
2519
2706
# whether the program is optimized by memory_optimize_transpiler
2520
2707
self .__is_mem_optimized = False
@@ -2565,6 +2752,15 @@ def op_role_var(self):
2565
2752
def set_op_role_var (self , var_name ):
2566
2753
self ._op_role_var = [var_name ]
2567
2754
2755
+ @contextlib .contextmanager
2756
+ def _backward_role_guard (self ):
2757
+ tmp_role = self ._current_role
2758
+
2759
+ OpRole = core .op_proto_and_checker_maker .OpRole
2760
+ self ._current_role = OpRole .Backward
2761
+ yield
2762
+ self ._current_role = tmp_role
2763
+
2568
2764
@signature_safe_contextmanager
2569
2765
def _optimized_guard (self , param_and_grads ):
2570
2766
"""
0 commit comments