@@ -46,12 +46,12 @@ class ErrorClipByValue(BaseErrorClipAttr):
46
46
self .min = min
47
47
48
48
def append_clip_op (self , block , grad_name ):
49
- block.append_op(
50
- type = " clip" ,
51
- inputs = { " X" : grad_name},
52
- outputs = { " Out" : grad_name},
53
- attrs = { " min" : self .min,
54
- " max" : self .max} )
49
+ clip_op_desc = block.desc. append_op()
50
+ clip_op_desc.set_type( " clip" )
51
+ clip_op_desc.set_input( " X" , [ grad_name])
52
+ clip_op_desc.set_output( " Out" , [ grad_name])
53
+ clip_op_desc.set_attr( " min" , self .min)
54
+ clip_op_desc.set_attr( " max" , self .max)
55
55
```
56
56
57
57
The ` BaseErrorClipAttr ` have one main member functions: ` append_clip_op(self, block, grad_name) ` .
@@ -80,6 +80,11 @@ def error_clip_callback(block, context):
80
80
op_desc.output_arg_names()):
81
81
fwd_var = block.var_recursive(grad_to_var[grad_n])
82
82
error_clip = getattr (fwd_var, " error_clip" , None )
83
+ if not (error_clip is None or isinstance (error_clip,
84
+ BaseErrorClipAttr)):
85
+ raise TypeError (
86
+ " Variable's error_clip should be an instance of BaseErrorClipAttr or None."
87
+ )
83
88
if error_clip is not None :
84
89
error_clip.append_clip_op(block, grad_n)
85
90
```
0 commit comments