Skip to content

Commit 2286cf7

Browse files
Merge pull request #548 from gongchensu/feature/add_mul_python_api
Feature/add mul python api
2 parents 2d0a83c + a565b36 commit 2286cf7

File tree

8 files changed

+273
-0
lines changed

8 files changed

+273
-0
lines changed

include/infinicore/ops/mul.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
/// Dispatch entry point for the element-wise multiplication operator.
class Mul {
8+
public:
9+
/// Signature of a backend kernel; `execute` invokes it as (c, a, b),
/// where `c` is the destination and `a`, `b` are the operands.
using schema = void (*)(Tensor, Tensor, Tensor);
10+
/// Runs the kernel selected by the dispatcher on (c, a, b).
static void execute(Tensor c, Tensor a, Tensor b);
11+
/// Returns the dispatcher in which backend kernels for this operator are registered.
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
/// Multiplies `a` by `b` and returns the result in a newly created tensor.
Tensor mul(Tensor a, Tensor b);
15+
/// Multiplies `a` by `b` and writes the result into the caller-provided tensor `c`.
void mul_(Tensor c, Tensor a, Tensor b);
16+
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from infinicore.ops.add import add
3131
from infinicore.ops.attention import attention
3232
from infinicore.ops.matmul import matmul
33+
from infinicore.ops.mul import mul
3334
from infinicore.ops.rearrange import rearrange
3435
from infinicore.tensor import (
3536
Tensor,
@@ -76,6 +77,7 @@
7677
"add",
7778
"attention",
7879
"matmul",
80+
"mul",
7981
"rearrange",
8082
"empty",
8183
"empty_like",

python/infinicore/ops/mul.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def mul(input, other, *, out=None):
    """Multiply two tensors element-wise.

    Args:
        input: First operand; must expose the native tensor as ``_underlying``.
        other: Second operand; must expose the native tensor as ``_underlying``.
        out: Optional destination tensor. When given, the product is written
            into it instead of allocating a new tensor.

    Returns:
        A new ``Tensor`` wrapping the product when ``out`` is ``None``;
        otherwise ``out`` itself.
    """
    if out is None:
        return Tensor(_infinicore.mul(input._underlying, other._underlying))

    _infinicore.mul_(out._underlying, input._underlying, other._underlying)
    # Hand the destination back so `mul(a, b, out=c)` can be used in an
    # expression, matching `torch.mul(..., out=...)`; previously this path
    # implicitly returned None.
    return out

src/infinicore/ops/mul/mul.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "infinicore/ops/mul.hpp"
2+
3+
namespace infinicore::op {
4+
5+
/// Returns the single dispatcher instance shared by all callers of `Mul`.
///
/// The instance is a function-local static, so it is constructed on first
/// use. That keeps it valid for kernels that register themselves from static
/// initializers in other translation units.
common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
    static common::OpDispatcher<Mul::schema> dispatcher_;
    return dispatcher_;
}
9+
10+
/// Looks up the kernel registered for the current context device type and
/// invokes it with destination `c` and operands `a`, `b`.
void Mul::execute(Tensor c, Tensor a, Tensor b) {
    const auto device_type = context::getDevice().getType();
    auto kernel = dispatcher().lookup(device_type);
    kernel(c, a, b);
}
13+
14+
/// Out-of-place multiplication: creates the destination tensor, fills it
/// with the product of `a` and `b`, and returns it.
///
/// The destination takes its shape, dtype and device from `a`.
/// NOTE(review): `b` is therefore assumed to be compatible with `a`'s
/// shape — confirm whether shape broadcasting is expected here.
Tensor mul(Tensor a, Tensor b) {
    auto result = Tensor::empty(a->shape(), a->dtype(), a->device());
    Mul::execute(result, a, b);
    return result;
}
19+
20+
/// Multiplication into a caller-provided destination: forwards (c, a, b)
/// unchanged to the dispatched `Mul` kernel.
void mul_(Tensor c, Tensor a, Tensor b) {
    return Mul::execute(c, a, b);
}
23+
24+
} // namespace infinicore::op
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/common/cache.hpp"
4+
#include "infinicore/ops/mul.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::mul_impl::infiniop {
8+
9+
// Per-thread store of InfiniOP mul descriptors, keyed by a size_t hash.
// The callback destroys a descriptor and resets the handle to null; it is
// presumably invoked by OpCache when an entry is dropped — confirm against
// common::OpCache.
thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
    100, // capacity
    [](infiniopMulDescriptor_t &desc) {
        if (desc == nullptr) {
            return; // nothing to release
        }
        INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc));
        desc = nullptr;
    });
17+
18+
/// InfiniOP-backed kernel for `Mul`: computes c = a * b on the current
/// context stream.
///
/// A mul descriptor is created once per cache key and reused on later calls
/// with the same key; workspace memory is allocated on every call.
void calculate(Tensor c, Tensor a, Tensor b) {
    // Cache key derived from the three tensors. The (c, b, a) order differs
    // from the descriptor argument order but is applied consistently.
    const size_t key = hash_combine(c, b, a);

    auto &cache = caches.getCache(context::getDevice().getType(),
                                  context::getDevice().getIndex());

    infiniopMulDescriptor_t desc = nullptr;
    if (auto cached = cache.get(key)) {
        desc = *cached;
    } else {
        // Cache miss: build the descriptor and remember it for next time.
        INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
            context::getInfiniopHandle(), &desc,
            c->desc(), a->desc(), b->desc()));
        cache.put(key, desc);
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopMul(
        desc, workspace->data(), workspace_size,
        c->data(), a->data(), b->data(), context::getStream()));
}
46+
47+
// Registers `calculate` with the Mul dispatcher during static initialization
// of this translation unit; the variable exists only to trigger that call.
// NOTE(review): the `false` flag is presumably "do not override existing
// registrations" — confirm against OpDispatcher::registerAll.
static bool registered = [] {
    Mul::dispatcher().registerAll(&calculate, false);
    return true;
}();
51+
52+
} // namespace infinicore::op::mul_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "ops/attention.hpp"
77
#include "ops/causal_softmax.hpp"
88
#include "ops/matmul.hpp"
9+
#include "ops/mul.hpp"
910
#include "ops/rearrange.hpp"
1011
#include "ops/rms_norm.hpp"
1112
#include "ops/silu.hpp"
@@ -20,6 +21,7 @@ inline void bind(py::module &m) {
2021
bind_attention(m);
2122
bind_causal_softmax(m);
2223
bind_matmul(m);
24+
bind_mul(m);
2325
bind_rearrange(m);
2426
bind_rms_norm(m);
2527
bind_silu(m);
src/infinicore/pybind11/ops/mul.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/mul.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
/// Exposes the multiplication operator on the given Python module:
///  - `mul(a, b)`     -> bound to `op::mul`
///  - `mul_(c, a, b)` -> bound to `op::mul_`
inline void bind_mul(py::module &m) {
    // Out-of-place variant.
    m.def("mul", &op::mul,
          py::arg("a"), py::arg("b"),
          R"doc(Element-wise multiplication of two tensors.)doc");

    // Variant writing into the caller-provided tensor `c`.
    m.def("mul_", &op::mul_,
          py::arg("c"), py::arg("a"), py::arg("b"),
          R"doc(In-place element-wise tensor multiplication.)doc");
}
25+
26+
} // namespace infinicore::ops

test/infinicore/ops/mul.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5+
6+
import torch
7+
import infinicore
8+
from framework.base import BaseOperatorTest, TensorSpec, TestCase
9+
from framework.runner import GenericTestRunner
10+
from framework.utils import is_broadcast
11+
12+
# ==============================================================================
13+
# Operator-specific configuration
14+
# ==============================================================================
15+
16+
# Test cases format: (shape, a_strides, b_strides, c_strides)
# Every operand of a case shares the same shape; only the strides vary.
# NOTE(review): a strides entry of None presumably means "default
# (contiguous) layout", and a 0 stride marks a broadcast dimension —
# confirm against TensorSpec.from_tensor / is_broadcast.
17+
_TEST_CASES_DATA = [
18+
((13, 4), None, None, None),
19+
((13, 4), (10, 1), (10, 1), (10, 1)),
20+
((13, 4), (0, 1), None, None),
21+
((13, 4, 4), None, None, None),
22+
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
23+
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
24+
((16, 5632), None, None, None),
25+
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
26+
]
27+
28+
# Data types: every test case is generated once per dtype listed here.
29+
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
30+
31+
# Tolerance: per-dtype comparison tolerances attached to each generated case.
# dtypes missing from this map fall back to {"atol": 0, "rtol": 1e-3} in
# build_test_cases.
32+
_TOLERANCE_MAP = {
33+
infinicore.float16: {"atol": 0, "rtol": 1e-2},
34+
infinicore.float32: {"atol": 0, "rtol": 1e-3},
35+
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
36+
}
37+
38+
39+
def build_test_cases():
    """Expand ``_TEST_CASES_DATA`` x ``_TENSOR_DTYPES`` into ``TestCase`` objects.

    For each (data entry, dtype) pair up to four cases are produced, in order:
      1. out-of-place (result returned),
      2. explicit output tensor, only if the ``c`` strides are not broadcast,
      3. writing into the first input, only if the ``a`` strides are not broadcast,
      4. writing into the second input, only if the ``b`` strides are not broadcast.

    Returns:
        list of ``TestCase`` in generation order.
    """
    cases = []

    for entry in _TEST_CASES_DATA:
        shape = entry[0]
        # Stride fields missing from a short entry default to None.
        a_strides, b_strides, c_strides = (tuple(entry[1:4]) + (None, None, None))[:3]

        # A broadcast tensor cannot serve as a write destination.
        a_writable = not is_broadcast(a_strides)
        b_writable = not is_broadcast(b_strides)
        c_writable = not is_broadcast(c_strides)

        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
            b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
            c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)

            # (enabled, kwargs, output_spec, comparison_target, description)
            variants = [
                # Out-of-place (return value)
                (True, {}, None, None, f"Mul - OUT_OF_PLACE (dtype={dtype})"),
                # With explicit output tensor (mul(a, b, out=c))
                (c_writable, {}, c_spec, "out", f"Mul - INPLACE(out) (dtype={dtype})"),
                # In-place on first input (mul(a, b, out=a))
                (a_writable, {"out": 0}, None, 0, f"Mul - INPLACE(a) (dtype={dtype})"),
                # In-place on second input (mul(a, b, out=b))
                (b_writable, {"out": 1}, None, 1, f"Mul - INPLACE(b) (dtype={dtype})"),
            ]

            cases.extend(
                TestCase(
                    inputs=[a_spec, b_spec],
                    kwargs=call_kwargs,
                    output_spec=output_spec,
                    comparison_target=target,
                    tolerance=tolerance,
                    description=description,
                )
                for enabled, call_kwargs, output_spec, target, description in variants
                if enabled
            )

    return cases
111+
112+
113+
_TEST_CASES = build_test_cases()
114+
115+
116+
class OpTest(BaseOperatorTest):
    """Mul test with simplified test case parsing"""

    def __init__(self):
        super().__init__("Mul")

    def get_test_cases(self):
        """Return the test cases generated at import time."""
        return _TEST_CASES

    def torch_operator(self, a, b, out=None, **kwargs):
        """Reference implementation: delegates to ``torch.mul``."""
        return torch.mul(a, b, out=out)

    def infinicore_operator(self, a, b, out=None, **kwargs):
        """Implementation under test: delegates to ``infinicore.mul``.

        Any ``AttributeError`` raised while resolving or running the call is
        re-raised as ``NotImplementedError``.
        """
        try:
            result = infinicore.mul(a, b, out=out)
        except AttributeError as exc:
            raise NotImplementedError("InfiniCore mul operator not available") from exc
        return result
133+
134+
135+
def main():
    """Main entry point"""
    # Build a runner for the Mul operator test and let it run and exit.
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)