Skip to content

Commit 86d8836

Browse files
Difersmaxiaolong001
authored andcommitted
add type_as (PaddlePaddle#74459)
1 parent 97725aa commit 86d8836

File tree

4 files changed

+165
-0
lines changed

4 files changed

+165
-0
lines changed

python/paddle/base/dygraph/math_op_patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ def conversion_method(self: Tensor) -> Tensor:
172172

173173
return methods
174174

175+
def type_as(self: Tensor, other: Tensor) -> Tensor:
176+
return self.astype(other.dtype)
177+
175178
def _scalar_elementwise_op_(
176179
var: Tensor, scale: float, bias: float
177180
) -> Tensor:
@@ -295,6 +298,7 @@ def _mT_(var: Tensor) -> Tensor:
295298
('astype', astype),
296299
('byte', byte),
297300
('uint8', byte),
301+
('type_as', type_as),
298302
('dim', dim),
299303
('ndimension', ndimension),
300304
('ndim', _ndim),

python/paddle/base/layers/math_op_patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ def astype(self, dtype):
382382
out.stop_gradient = self.stop_gradient
383383
return out
384384

385+
def type_as(self, other):
386+
return self.astype(other.dtype)
387+
385388
@static_only
386389
def append(self, var):
387390
"""
@@ -799,6 +802,7 @@ def to_dense(var):
799802
('__neg__', _neg_),
800803
('__abs__', _abs_),
801804
('astype', astype),
805+
('type_as', type_as),
802806
('cpu', cpu),
803807
('cuda', cuda),
804808
('place', place),

python/paddle/pir/math_op_patch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,9 @@ def conversion_method(self):
434434
methods.append((method_name, method_impl))
435435
return methods
436436

437+
def type_as(self, other):
438+
return self.astype(other.dtype)
439+
437440
def _scalar_add_(var, value):
438441
return paddle.scale(var, 1.0, value)
439442

@@ -1175,6 +1178,7 @@ def register_hook(self, hook):
11751178
('astype', astype),
11761179
('byte', byte),
11771180
('uint8', byte),
1181+
('type_as', type_as),
11781182
('size', _size_),
11791183
('T', _T_),
11801184
('mT', _mT_),

test/legacy_test/test_type_as.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
19+
import paddle
20+
from paddle import base
21+
22+
23+
def api_warpprt(x, y):
24+
return x.type_as(y)
25+
26+
27+
class TestTypeAsBase(unittest.TestCase):
28+
def setUp(self):
29+
self.input_dtype_1 = "float32"
30+
self.input_dtype_2 = "float16"
31+
self.input_shape = (2, 3)
32+
33+
self.input_np_1 = self.generate_data(
34+
self.input_dtype_1, self.input_shape
35+
)
36+
self.input_np_2 = self.generate_data(
37+
self.input_dtype_2, self.input_shape
38+
)
39+
40+
self.input_shape_1 = self.input_np_1.shape
41+
self.input_shape_2 = self.input_np_2.shape
42+
43+
self.op_static = api_warpprt
44+
self.op_dygraph = api_warpprt
45+
self.places = [None, paddle.CPUPlace()]
46+
47+
def generate_data(self, dtype, shape):
48+
if "int" in dtype:
49+
data = np.arange(1, np.prod(shape) + 1).reshape(shape)
50+
else:
51+
data = np.arange(1, np.prod(shape) + 1, dtype='float32').reshape(
52+
shape
53+
)
54+
return data.astype(dtype)
55+
56+
def check_static_result(self, place):
57+
paddle.enable_static()
58+
main_prog = paddle.static.Program()
59+
startup_prog = paddle.static.Program()
60+
with paddle.static.program_guard(main_prog, startup_prog):
61+
input_name_1 = 'input_1'
62+
input_name_2 = 'input_2'
63+
input_var_1 = paddle.static.data(
64+
name=input_name_1,
65+
shape=self.input_shape_1,
66+
dtype=self.input_dtype_1,
67+
)
68+
input_var_2 = paddle.static.data(
69+
name=input_name_2,
70+
shape=self.input_shape_2,
71+
dtype=self.input_dtype_2,
72+
)
73+
res = self.op_static(input_var_1, input_var_2)
74+
exe = base.Executor(place)
75+
fetches = exe.run(
76+
main_prog,
77+
feed={
78+
input_name_1: self.input_np_1,
79+
input_name_2: self.input_np_2,
80+
},
81+
fetch_list=[res],
82+
)
83+
self.assertEqual(fetches[0].dtype, np.dtype(self.input_dtype_2))
84+
85+
def test_static(self):
86+
for place in self.places:
87+
self.check_static_result(place=place)
88+
89+
def check_dygraph_result(self, place):
90+
with base.dygraph.guard(place):
91+
input_1 = paddle.to_tensor(self.input_np_1)
92+
input_2 = paddle.to_tensor(self.input_np_2)
93+
result = self.op_dygraph(input_1, input_2)
94+
self.assertEqual(result.dtype, input_2.dtype)
95+
96+
def test_dygraph(self):
97+
for place in self.places:
98+
self.check_dygraph_result(place=place)
99+
100+
101+
class TestTypeAsFloat32ToFloat16(TestTypeAsBase):
102+
def setUp(self):
103+
self.input_dtype_1 = "float32"
104+
self.input_dtype_2 = "float16"
105+
super().setUp()
106+
107+
108+
class TestTypeAsFloat64ToFloat32(TestTypeAsBase):
109+
def setUp(self):
110+
self.input_dtype_1 = "float64"
111+
self.input_dtype_2 = "float32"
112+
super().setUp()
113+
114+
115+
class TestTypeAsInt32ToInt64(TestTypeAsBase):
116+
def setUp(self):
117+
self.input_dtype_1 = "int32"
118+
self.input_dtype_2 = "int64"
119+
super().setUp()
120+
121+
122+
class TestTypeAsInt32ToFloat32(TestTypeAsBase):
123+
def setUp(self):
124+
self.input_dtype_1 = "int32"
125+
self.input_dtype_2 = "float32"
126+
super().setUp()
127+
128+
129+
class TestTypeAsFloat32ToInt64(TestTypeAsBase):
130+
def setUp(self):
131+
self.input_dtype_1 = "float32"
132+
self.input_dtype_2 = "int64"
133+
super().setUp()
134+
135+
136+
class TestTypeAsInt8ToFloat64(TestTypeAsBase):
137+
def setUp(self):
138+
self.input_dtype_1 = "int8"
139+
self.input_dtype_2 = "float64"
140+
self.input_shape = (4, 2)
141+
super().setUp()
142+
143+
144+
class TestTypeAsUInt8ToInt32(TestTypeAsBase):
145+
def setUp(self):
146+
self.input_dtype_1 = "uint8"
147+
self.input_dtype_2 = "int32"
148+
self.input_shape = (3, 3)
149+
super().setUp()
150+
151+
152+
if __name__ == "__main__":
153+
unittest.main()

0 commit comments

Comments
 (0)