19
19
20
20
21
21
def monkey_patch_variable ():
22
- def new_name ():
22
+ def unique_tmp_name ():
23
23
return unique_name ("tmp" )
24
24
25
25
def safe_get_dtype (var ):
@@ -29,21 +29,9 @@ def safe_get_dtype(var):
29
29
raise ValueError ("Cannot get data type from %s" , var .name )
30
30
return dtype
31
31
32
- def create_scalar (block , value , dtype ):
33
- value = float (value )
34
- tmp_name = new_name ()
35
- var = block .create_var (name = tmp_name , shape = [1 ], dtype = dtype )
36
- block .append_op (
37
- type = "fill" ,
38
- outputs = {"Out" : [var ]},
39
- attrs = {"value" : [value ],
40
- "shape" : [1 ],
41
- "dtype" : dtype })
42
- return var
43
-
44
32
def create_tensor (block , value , dtype , shape ):
45
33
value = float (value )
46
- tmp_name = new_name ()
34
+ tmp_name = unique_tmp_name ()
47
35
var = block .create_var (name = tmp_name , shape = shape , dtype = dtype )
48
36
block .append_op (
49
37
type = "fill_constant" ,
@@ -53,10 +41,13 @@ def create_tensor(block, value, dtype, shape):
53
41
'value' : value })
54
42
return var
55
43
44
+ def create_scalar (block , value , dtype ):
45
+ return create_tensor (block , value , dtype , shape = [1 ])
46
+
56
47
def create_tensor_with_batchsize (ref_var , value , dtype ):
57
48
assert isinstance (ref_var , Variable )
58
49
value = float (value )
59
- tmp_name = new_name ()
50
+ tmp_name = unique_tmp_name ()
60
51
var = ref_var .block .create_var (name = tmp_name , dtype = dtype )
61
52
ref_var .block .append_op (
62
53
type = 'fill_constant_batch_size_like' ,
@@ -68,7 +59,7 @@ def create_tensor_with_batchsize(ref_var, value, dtype):
68
59
69
60
def astype (self , dtype ):
70
61
"""
71
- Cast a variable to data type.
62
+ Cast a variable to a specified data type.
72
63
NOTE: The variable must be a Tensor
73
64
Args:
74
65
self(Variable): The source variable
@@ -77,7 +68,7 @@ def astype(self, dtype):
77
68
Returns:
78
69
Variable with new dtype
79
70
"""
80
- tmp_name = new_name ()
71
+ tmp_name = unique_tmp_name ()
81
72
out = self .block .create_var (name = tmp_name , dtype = dtype )
82
73
self .block .append_op (
83
74
type = "cast" ,
@@ -120,7 +111,7 @@ def __impl__(self, other_var):
120
111
self = other_var
121
112
other_var = tmp
122
113
123
- tmp_name = new_name ()
114
+ tmp_name = unique_tmp_name ()
124
115
out = self .block .create_var (name = tmp_name , dtype = lhs_dtype )
125
116
self .block .append_op (
126
117
type = op_type ,
0 commit comments