Skip to content

Commit 1a618ff

Browse files
authored
Merge pull request #590 from pengcheng888/main
issue/588 - 为Tensor添加from_torch函数, +*运算符重载
2 parents 576b755 + e4b7e5e commit 1a618ff

File tree

4 files changed

+117
-3
lines changed

4 files changed

+117
-3
lines changed

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
empty,
3838
empty_like,
3939
from_blob,
40+
from_torch,
4041
ones,
4142
strided_empty,
4243
strided_from_blob,
@@ -82,6 +83,7 @@
8283
"empty",
8384
"empty_like",
8485
"from_blob",
86+
"from_torch",
8587
"ones",
8688
"strided_empty",
8789
"strided_from_blob",

python/infinicore/tensor.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import infinicore.dtype
33
from infinicore.lib import _infinicore
44

5+
from .utils import to_infinicore_dtype
6+
57

68
class Tensor:
7-
def __init__(self, underlying):
9+
def __init__(self, underlying, *, _torch_ref=None):
810
"""An internal method. Please do not use this directly."""
911

1012
self._underlying = underlying
@@ -15,6 +17,8 @@ def __init__(self, underlying):
1517
self._underlying.device
1618
)
1719

20+
self._torch_ref = _torch_ref
21+
1822
@property
1923
def shape(self):
2024
return self._underlying.shape
@@ -86,6 +90,12 @@ def debug(self, filename=None):
8690
else:
8791
self._underlying.debug(filename)
8892

93+
def __add__(self, other):
94+
return infinicore.add(self, other)
95+
96+
def __mul__(self, other):
97+
return infinicore.mul(self, other)
98+
8999

90100
def empty(size, *, dtype=None, device=None, pin_memory=False):
91101
return Tensor(
@@ -135,3 +145,17 @@ def strided_from_blob(data_ptr, size, strides, *, dtype=None, device=None):
135145
data_ptr, size, strides, dtype._underlying, device._underlying
136146
)
137147
)
148+
149+
150+
def from_torch(torch_tensor) -> Tensor:
151+
infini_type = to_infinicore_dtype(torch_tensor.dtype)
152+
infini_device = infinicore.device(torch_tensor.device.type, 0)
153+
return Tensor(
154+
_infinicore.from_blob(
155+
torch_tensor.data_ptr(),
156+
list(torch_tensor.shape),
157+
dtype=infini_type._underlying,
158+
device=infini_device._underlying,
159+
),
160+
torch_ref=torch_tensor,
161+
)

python/infinicore/utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import torch
2+
3+
import infinicore
4+
5+
6+
def to_torch_dtype(infini_dtype):
7+
"""Convert infinicore data type to PyTorch data type"""
8+
if infini_dtype == infinicore.float16:
9+
return torch.float16
10+
elif infini_dtype == infinicore.float32:
11+
return torch.float32
12+
elif infini_dtype == infinicore.bfloat16:
13+
return torch.bfloat16
14+
elif infini_dtype == infinicore.int8:
15+
return torch.int8
16+
elif infini_dtype == infinicore.int16:
17+
return torch.int16
18+
elif infini_dtype == infinicore.int32:
19+
return torch.int32
20+
elif infini_dtype == infinicore.int64:
21+
return torch.int64
22+
elif infini_dtype == infinicore.uint8:
23+
return torch.uint8
24+
else:
25+
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
26+
27+
28+
def to_infinicore_dtype(torch_dtype):
29+
"""Convert PyTorch data type to infinicore data type"""
30+
if torch_dtype == torch.float32:
31+
return infinicore.float32
32+
elif torch_dtype == torch.float16:
33+
return infinicore.float16
34+
elif torch_dtype == torch.bfloat16:
35+
return infinicore.bfloat16
36+
elif torch_dtype == torch.int8:
37+
return infinicore.int8
38+
elif torch_dtype == torch.int16:
39+
return infinicore.int16
40+
elif torch_dtype == torch.int32:
41+
return infinicore.int32
42+
elif torch_dtype == torch.int64:
43+
return infinicore.int64
44+
elif torch_dtype == torch.uint8:
45+
return infinicore.uint8
46+
else:
47+
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")

test/infinicore/test.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import infinicore
21
import torch
32

3+
import infinicore
4+
45

56
def test():
67
shape = [2, 3, 4]
@@ -40,5 +41,45 @@ def test():
4041
print("Test passed")
4142

4243

44+
def test2():
45+
"测试infinicore.Tensor的from_torch, +* 运算符功能"
46+
shape = [1, 2, 3]
47+
48+
x1_torch = torch.rand(shape, dtype=torch.float32, device="cpu")
49+
x2_torch = torch.rand(shape, dtype=torch.float32, device="cpu")
50+
51+
x1_infini = infinicore.from_torch(x1_torch.clone())
52+
x2_infini = infinicore.from_torch(x2_torch.clone())
53+
54+
ans1_infini = x1_infini + x2_infini
55+
ans2_infini = x1_infini * x2_infini
56+
57+
ans1_torch_ref = x1_torch + x2_torch
58+
ans2_torch_ref = x1_torch * x2_torch
59+
60+
print("----------------------------------------")
61+
torch_ans1_result = torch.zeros(shape, dtype=torch.float32, device="cpu")
62+
torch_ans2_result = torch.zeros(shape, dtype=torch.float32, device="cpu")
63+
torch_ans1 = infinicore.from_blob(
64+
torch_ans1_result.data_ptr(),
65+
shape,
66+
dtype=infinicore.float32,
67+
device=infinicore.device("cpu", 0),
68+
)
69+
torch_ans2 = infinicore.from_blob(
70+
torch_ans2_result.data_ptr(),
71+
shape,
72+
dtype=infinicore.float32,
73+
device=infinicore.device("cpu", 0),
74+
)
75+
torch_ans1.copy_(ans1_infini)
76+
torch_ans2.copy_(ans2_infini)
77+
78+
print("----------------------------------------")
79+
print("abs error: ", torch.abs(ans1_torch_ref - torch_ans1_result).max())
80+
print("abs error: ", torch.abs(ans2_torch_ref - torch_ans2_result).max())
81+
82+
4383
if __name__ == "__main__":
44-
test()
84+
# test()
85+
test2()

0 commit comments

Comments
 (0)