Skip to content

Commit 158d5a2

Browse files
committed
[T1-1-1]: Where operator with cpu nvidia metax iluvatar and test
1 parent 7c84868 commit 158d5a2

File tree

11 files changed

+902
-0
lines changed

11 files changed

+902
-0
lines changed

include/infiniop/ops/where.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#ifndef __INFINIOP_WHERE_API_H__
2+
#define __INFINIOP_WHERE_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopWhereDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle,
9+
infiniopWhereDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t c,
11+
infiniopTensorDescriptor_t a,
12+
infiniopTensorDescriptor_t b,
13+
infiniopTensorDescriptor_t condition);
14+
15+
__C __export infiniStatus_t infiniopGetWhereWorkspaceSize(infiniopWhereDescriptor_t desc, size_t *size);
16+
17+
__C __export infiniStatus_t infiniopWhere(infiniopWhereDescriptor_t desc,
18+
void *workspace,
19+
size_t workspace_size,
20+
void *c,
21+
const void *a,
22+
const void *b,
23+
const void *condition,
24+
void *stream);
25+
26+
__C __export infiniStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc);
27+
28+
#endif
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
#include "ops.hpp"
2+
#include "utils.hpp"
3+
#include <infinirt.h>
4+
#include <iomanip>
5+
#include <iostream>
6+
7+
namespace infiniop_test::where {
8+
struct Test::Attributes {
9+
std::shared_ptr<Tensor> a;
10+
std::shared_ptr<Tensor> b;
11+
std::shared_ptr<Tensor> condition;
12+
std::shared_ptr<Tensor> c;
13+
std::shared_ptr<Tensor> ans;
14+
};
15+
16+
std::shared_ptr<Test> Test::build(
17+
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
18+
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
19+
double rtol, double atol) {
20+
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
21+
test->_attributes = new Attributes();
22+
if (tensors.find("a") == tensors.end()
23+
|| tensors.find("b") == tensors.end()
24+
|| tensors.find("condition") == tensors.end()
25+
|| tensors.find("c") == tensors.end()
26+
|| tensors.find("ans") == tensors.end()) {
27+
throw std::runtime_error("Invalid Test");
28+
}
29+
30+
test->_attributes->a = tensors["a"];
31+
test->_attributes->b = tensors["b"];
32+
test->_attributes->condition = tensors["condition"];
33+
test->_attributes->c = tensors["c"];
34+
test->_attributes->ans = tensors["ans"];
35+
36+
auto elemType = test->_attributes->a->ggml_type();
37+
if (elemType == GGML_TYPE_I8) {
38+
test->_rtol = 1e-5;
39+
test->_atol = 1e-5;
40+
}
41+
if (elemType == GGML_TYPE_I16) {
42+
test->_rtol = 1e-5;
43+
test->_atol = 1e-5;
44+
}
45+
if (elemType == GGML_TYPE_I32) {
46+
test->_rtol = 1e-5;
47+
test->_atol = 1e-5;
48+
}
49+
if (elemType == GGML_TYPE_I64) {
50+
test->_rtol = 1e-5;
51+
test->_atol = 1e-5;
52+
}
53+
if (elemType == GGML_TYPE_F16) {
54+
test->_rtol = 1e-7;
55+
test->_atol = 1e-7;
56+
}
57+
if (elemType == GGML_TYPE_F32) {
58+
test->_rtol = 1e-7;
59+
test->_atol = 1e-7;
60+
}
61+
if (elemType == GGML_TYPE_F64) {
62+
test->_rtol = 1e-7;
63+
test->_atol = 1e-7;
64+
}
65+
if (elemType == GGML_TYPE_BF16) {
66+
test->_rtol = 1e-5;
67+
test->_atol = 1e-5;
68+
}
69+
70+
return test;
71+
}
72+
73+
std::shared_ptr<infiniop_test::Result> Test::run(
74+
infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
75+
infiniopWhereDescriptor_t op_desc;
76+
auto a = _attributes->a->to(device, device_id);
77+
auto b = _attributes->b->to(device, device_id);
78+
auto condition = _attributes->condition->to(device, device_id);
79+
auto c = _attributes->c->to(device, device_id);
80+
CHECK_OR(infiniopCreateWhereDescriptor(handle, &op_desc,
81+
c->desc(),
82+
a->desc(),
83+
b->desc(),
84+
condition->desc()),
85+
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
86+
size_t workspace_size;
87+
CHECK_OR(infiniopGetWhereWorkspaceSize(op_desc, &workspace_size),
88+
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
89+
void *workspace;
90+
CHECK_OR(infinirtMalloc(&workspace, workspace_size),
91+
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
92+
CHECK_OR(infiniopWhere(op_desc, workspace, workspace_size,
93+
c->data(),
94+
a->data(),
95+
b->data(),
96+
condition->data(),
97+
nullptr),
98+
return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));
99+
100+
try {
101+
allClose(c, _attributes->ans, _rtol, _atol);
102+
} catch (const std::exception &e) {
103+
return TEST_FAILED(RESULT_INCORRECT, e.what());
104+
}
105+
106+
double elapsed_time = 0.;
107+
108+
elapsed_time = benchmark(
109+
[=]() {
110+
infiniopWhere(
111+
op_desc, workspace, workspace_size,
112+
c->data(),
113+
a->data(),
114+
b->data(),
115+
condition->data(),
116+
nullptr);
117+
},
118+
warm_ups, iterations);
119+
120+
return TEST_PASSED(elapsed_time);
121+
}
122+
123+
std::vector<std::string> Test::attribute_names() {
124+
return {};
125+
}
126+
127+
std::vector<std::string> Test::tensor_names() {
128+
return {"a", "b", "condition", "c", "ans"};
129+
}
130+
131+
std::vector<std::string> Test::output_names() {
132+
return {"c"};
133+
}
134+
135+
std::string Test::toString() const {
136+
std::ostringstream oss;
137+
oss << op_name() << std::endl;
138+
oss << "- a: " << _attributes->a->info() << std::endl;
139+
oss << "- b: " << _attributes->b->info() << std::endl;
140+
oss << "- condition: " << _attributes->condition->info() << std::endl;
141+
oss << "- c: " << _attributes->c->info() << std::endl;
142+
oss << std::scientific << std::setprecision(2);
143+
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
144+
return oss.str();
145+
}
146+
147+
Test::~Test() {
148+
delete _attributes;
149+
}
150+
151+
} // namespace infiniop_test::where
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include "where_cpu.h"
2+
3+
namespace op::where::cpu {
4+
5+
Descriptor::~Descriptor() = default;
6+
7+
infiniStatus_t Descriptor::create(
8+
infiniopHandle_t handle_,
9+
Descriptor **desc_ptr,
10+
infiniopTensorDescriptor_t out_desc,
11+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
12+
13+
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
14+
auto dtype = out_desc->dtype();
15+
16+
const auto &a_desc = input_desc_vec.at(0);
17+
const auto &b_desc = input_desc_vec.at(1);
18+
const auto &cond_desc = input_desc_vec.at(2);
19+
20+
const auto &c_shape = out_desc->shape();
21+
const auto &a_shape = a_desc->shape();
22+
const auto &b_shape = b_desc->shape();
23+
const auto &cond_shape = cond_desc->shape();
24+
25+
CHECK_DTYPE(cond_desc->dtype(),
26+
INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16,
27+
INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64,
28+
INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64,
29+
INFINI_DTYPE_BOOL);
30+
31+
CHECK_DTYPE(dtype,
32+
INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16,
33+
INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64,
34+
INFINI_DTYPE_U8, INFINI_DTYPE_U16, INFINI_DTYPE_U32, INFINI_DTYPE_U64,
35+
INFINI_DTYPE_BOOL);
36+
37+
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape, cond_shape);
38+
39+
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
40+
41+
return INFINI_STATUS_SUCCESS;
42+
}
43+
44+
infiniStatus_t Descriptor::calculate(
45+
void *workspace,
46+
size_t workspace_size,
47+
void *output,
48+
std::vector<const void *> inputs,
49+
void *stream) const {
50+
51+
switch (_dtype) {
52+
case INFINI_DTYPE_F16:
53+
return _device_info->calculate<WhereOp, fp16_t>(_info, output, inputs, stream);
54+
case INFINI_DTYPE_BF16:
55+
return _device_info->calculate<WhereOp, bf16_t>(_info, output, inputs, stream);
56+
case INFINI_DTYPE_F32:
57+
return _device_info->calculate<WhereOp, float>(_info, output, inputs, stream);
58+
case INFINI_DTYPE_F64:
59+
return _device_info->calculate<WhereOp, double>(_info, output, inputs, stream);
60+
case INFINI_DTYPE_I8:
61+
return _device_info->calculate<WhereOp, int8_t>(_info, output, inputs, stream);
62+
case INFINI_DTYPE_I16:
63+
return _device_info->calculate<WhereOp, int16_t>(_info, output, inputs, stream);
64+
case INFINI_DTYPE_I32:
65+
return _device_info->calculate<WhereOp, int32_t>(_info, output, inputs, stream);
66+
case INFINI_DTYPE_I64:
67+
return _device_info->calculate<WhereOp, int64_t>(_info, output, inputs, stream);
68+
case INFINI_DTYPE_U8:
69+
return _device_info->calculate<WhereOp, uint8_t>(_info, output, inputs, stream);
70+
case INFINI_DTYPE_U16:
71+
return _device_info->calculate<WhereOp, uint16_t>(_info, output, inputs, stream);
72+
case INFINI_DTYPE_U32:
73+
return _device_info->calculate<WhereOp, uint32_t>(_info, output, inputs, stream);
74+
case INFINI_DTYPE_U64:
75+
return _device_info->calculate<WhereOp, uint64_t>(_info, output, inputs, stream);
76+
case INFINI_DTYPE_BOOL:
77+
return _device_info->calculate<WhereOp, bool>(_info, output, inputs, stream);
78+
default:
79+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
80+
}
81+
82+
return INFINI_STATUS_SUCCESS;
83+
}
84+
} // namespace op::where::cpu
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef __WHERE_CPU_H__
2+
#define __WHERE_CPU_H__
3+
4+
#include "../../../elementwise/cpu/elementwise_cpu.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(where, cpu)
7+
8+
namespace op::where::cpu {
9+
typedef struct WhereOp {
10+
public:
11+
static constexpr size_t num_inputs = 3;
12+
template <typename T>
13+
T operator()(const T &a, const T &b, const T &cond) const {
14+
return cond ? a : b;
15+
}
16+
} WhereOp;
17+
} // namespace op::where::cpu
18+
19+
#endif // __WHERE_CPU_H__
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef __WHERE_CUDA_H__
2+
#define __WHERE_CUDA_H__
3+
4+
namespace op::where::cuda {
5+
typedef struct WhereOp {
6+
public:
7+
static constexpr size_t num_inputs = 3;
8+
template <typename T>
9+
__device__ __forceinline__ T operator()(const T &a, const T &b, const T &cond) const {
10+
return cond ? a : b;
11+
}
12+
} WhereOp;
13+
} // namespace op::where::cuda
14+
15+
#endif // __WHERE_CUDA_H__
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __WHERE_METAX_API_H__
2+
#define __WHERE_METAX_API_H__
3+
4+
#include "../../../elementwise/metax/elementwise_metax_api.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(where, metax)
7+
8+
#endif // __WHERE_METAX_API_H__
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#include "where_metax.h"
2+
3+
#include "../../../elementwise/metax/elementwise_metax.h"
4+
5+
#include "../cuda/kernel.cuh"
6+
7+
namespace op::where::metax {
8+
9+
Descriptor::~Descriptor() = default;
10+
11+
infiniStatus_t Descriptor::create(
12+
infiniopHandle_t handle_,
13+
Descriptor **desc_ptr,
14+
infiniopTensorDescriptor_t out_desc,
15+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
16+
17+
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
18+
auto dtype = out_desc->dtype();
19+
20+
const auto &a_desc = input_desc_vec.at(0);
21+
const auto &b_desc = input_desc_vec.at(1);
22+
const auto &c_shape = out_desc->shape();
23+
const auto &a_shape = a_desc->shape();
24+
const auto &b_shape = b_desc->shape();
25+
26+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
27+
28+
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
29+
30+
// create CUDA elementwise descriptor
31+
CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
32+
33+
return INFINI_STATUS_SUCCESS;
34+
}
35+
36+
infiniStatus_t Descriptor::calculate(
37+
void *workspace,
38+
size_t workspace_size,
39+
void *output,
40+
std::vector<const void *> inputs,
41+
void *stream) const {
42+
43+
if (workspace_size < _workspace_size) {
44+
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
45+
}
46+
47+
switch (_dtype) {
48+
case INFINI_DTYPE_F16:
49+
return _device_info->calculate<256, cuda::WhereOp, half>(_info, workspace, output, inputs, stream);
50+
case INFINI_DTYPE_BF16:
51+
return _device_info->calculate<256, cuda::WhereOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
52+
case INFINI_DTYPE_F32:
53+
return _device_info->calculate<256, cuda::WhereOp, float>(_info, workspace, output, inputs, stream);
54+
case INFINI_DTYPE_F64:
55+
return _device_info->calculate<256, cuda::WhereOp, double>(_info, workspace, output, inputs, stream);
56+
default:
57+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
58+
}
59+
60+
return INFINI_STATUS_SUCCESS;
61+
}
62+
} // namespace op::where::metax

0 commit comments

Comments
 (0)