Skip to content

Commit 2d84974

Browse files
committed
add arrayfire interop
1 parent 82710f6 commit 2d84974

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

Etaler/Interop/Arrayfire.hpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#pragma once
2+
3+
#include <arrayfire.h>
4+
5+
#include <Etaler/Core/Tensor.hpp>
6+
7+
namespace et
8+
{
9+
10+
Tensor from_afarray(const af::array& arr, bool transpose=true)
11+
{
12+
et_assert(arr.type() == f32 || arr.type() == s32 || arr.type() == b8);
13+
14+
auto a = arr;
15+
if(transpose) //ArrayFire stores data in fortran order, we might want to transpose it to make it C order
16+
a = af::transpose(arr);
17+
18+
//Convert to et::Shape
19+
Shape s;
20+
auto dims = a.dims();
21+
for(dim_t i=0;i<dims.ndims();i++)
22+
s.push_back(dims[i]);
23+
24+
DType dtype = [](auto type) {
25+
if(type == f32)
26+
return DType::Float;
27+
else if(type == s32)
28+
return DType::Int32;
29+
else
30+
return DType::Bool;
31+
}(arr.type());
32+
33+
//Copy data from AF to Etaler
34+
Tensor res;
35+
if(dtype == DType::Float) {
36+
auto ptr = a.host<float>();
37+
res = Tensor(s, ptr);
38+
af::freeHost(ptr);
39+
}
40+
else if(dtype == DType::Int32) {
41+
auto ptr = a.host<int>();
42+
res = Tensor(s, ptr);
43+
af::freeHost(ptr);
44+
}
45+
else {
46+
auto ptr = a.host<uint8_t>(); //Some arrayfire quarks
47+
res = Tensor(s, ptr);
48+
af::freeHost(ptr);
49+
}
50+
51+
return res;
52+
}
53+
54+
af::array to_afarray(const Tensor& t, bool transpose=true)
55+
{
56+
et_assert(t.dimentions() <= 4);
57+
af::dim4 dims;
58+
//Initalize the dims (not Initalized by default)
59+
for(int i=0;i<4;i++)
60+
dims[i] = 1;
61+
for(size_t i=0;i<t.dimentions();i++)
62+
dims[4-t.dimentions()+i] = t.shape()[i];
63+
64+
af::dtype dtype = [](DType dtype) {
65+
if(dtype == DType::Float)
66+
return f32;
67+
else if(dtype == DType::Int32)
68+
return s32;
69+
else
70+
return b8;
71+
}(t.dtype());
72+
af::array res(dims, dtype);
73+
74+
if(dtype == f32) {
75+
auto v = t.toHost<float>();
76+
res.write(v.data(), v.size()*dtypeToSize(t.dtype()));
77+
}
78+
else if(dtype == s32) {
79+
auto v = t.toHost<int32_t>();
80+
res.write(v.data(), v.size()*dtypeToSize(t.dtype()));
81+
}
82+
else {
83+
auto v = t.toHost<uint8_t>();
84+
res.write(v.data(), v.size()*dtypeToSize(t.dtype()));
85+
}
86+
87+
if(transpose)
88+
return res.T();
89+
return res;
90+
91+
92+
}
93+
94+
}

0 commit comments

Comments
 (0)