
Commit 7a19fff

Author: pengcheng888 (committed)

issue/581 - Add the linear implementation, interface, and tests

1 parent 2d0a83c commit 7a19fff
File tree

6 files changed: +259 -0 lines changed

include/infinicore/ops/linear.hpp

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
#pragma once

#include "common/op.hpp"
#include <pybind11/pybind11.h>

namespace infinicore::op {

Tensor linear(Tensor input, Tensor weight, pybind11::object bias);

void linear_(Tensor out, Tensor input, Tensor weight, pybind11::object bias);

} // namespace infinicore::op

python/infinicore/nn/functional.py

Lines changed: 21 additions & 0 deletions
@@ -65,3 +65,24 @@ def swiglu(input: Tensor, other: Tensor, *, out=None):
    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)

    return out


def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
    r"""Applies a linear transformation to the incoming data: y=xA^T+b."""

    if out is None:
        return Tensor(
            _infinicore.linear(
                input._underlying,
                weight._underlying,
                None if bias is None else bias._underlying,
            )
        )

    _infinicore.linear_(
        out._underlying,
        input._underlying,
        weight._underlying,
        None if bias is None else bias._underlying,
    )
    return out
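
The new wrapper follows the same shape contract as torch.nn.functional.linear, which the test file below also uses as the reference. A minimal PyTorch sketch of that contract, with shapes borrowed from the first test case (the tensors here are illustrative, not part of the commit):

import torch

# Shapes from the first test case: bs=1, n=5, in_features=2048, out_features=5632
x = torch.randn(1, 5, 2048)   # (..., in_features)
w = torch.randn(5632, 2048)   # (out_features, in_features)
b = torch.randn(5632)         # (out_features,)

# y = x @ w.T + b: only the last dimension changes, from in_features to out_features
y = torch.nn.functional.linear(x, w, b)
assert y.shape == (1, 5, 5632)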
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
#include "infinicore/ops/linear.hpp"
#include "infinicore/ops/add.hpp"
#include "infinicore/ops/matmul.hpp"

namespace infinicore::op {

Tensor linear(Tensor input,
              Tensor weight,
              pybind11::object bias) {
    Size ndim = input->ndim();
    Size out_features = weight->shape()[0];

    // Allocate the output tensor
    auto output_shape = input->shape();
    output_shape[ndim - 1] = out_features;
    auto out = Tensor::empty(output_shape, input->dtype(), input->device());

    // Compute in place into the freshly allocated tensor
    linear_(out, input, weight, bias);
    return out;
}

void linear_(Tensor out,
             Tensor input,
             Tensor weight,
             pybind11::object bias) {

    auto weight_shape = weight->shape();
    Size out_features = weight_shape[0];
    Size in_features = weight_shape[1];

    Size ndim = input->ndim();
    assert(out->ndim() == ndim);

    // Flatten all leading dimensions into a single batch dimension N
    Size N = 1;
    auto input_shape = input->shape();
    for (size_t i = 0; i < ndim - 1; ++i) {
        N *= input_shape[i];
    }

    // Linear transformation: out = input @ weight^T
    Tensor out_view = out->view({N, out_features});
    matmul_(out_view,
            input->view({N, in_features}),
            weight->permute({1, 0}));

    // Add the bias, broadcast over the N rows via a zero-stride view
    if (!bias.is_none()) {
        Tensor bias_tensor = bias.cast<Tensor>();
        add_(out_view,
             out_view,
             bias_tensor->as_strided({N, out_features}, {0, 1}));
    }
}

} // namespace infinicore::op
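
For readers who find the tensor views easier to follow in Python, here is a rough PyTorch sketch of what linear_ does above: flatten the leading dimensions into one batch dimension N, matmul against the transposed weight, then broadcast-add the bias over the N rows. The function name and the use of torch are illustrative only, not part of this commit:

import torch

def linear_reference(out, input, weight, bias=None):
    # Mirror of the C++ linear_ kernel above, expressed with PyTorch ops
    out_features, in_features = weight.shape
    n = input.numel() // in_features          # product of all leading dimensions
    out_view = out.view(n, out_features)
    # out = input @ weight^T, written into the preallocated output
    torch.matmul(input.reshape(n, in_features), weight.t(), out=out_view)
    if bias is not None:
        # Analogue of the zero-stride as_strided view used for the bias in C++
        out_view += bias.expand(n, out_features)
    return out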

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include "ops/add.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
@@ -19,6 +20,7 @@ inline void bind(py::module &m) {
     bind_add(m);
     bind_attention(m);
     bind_causal_softmax(m);
+    bind_linear(m);
     bind_matmul(m);
     bind_rearrange(m);
     bind_rms_norm(m);
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
#pragma once

#include "infinicore/ops/linear.hpp"
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_linear(py::module &m) {
    m.def("linear",
          &op::linear,
          py::arg("input"),
          py::arg("weight"),
          py::arg("bias") = py::none(),
          R"doc(Applies a linear transformation to the incoming data: y=xA^T+b.)doc");

    m.def("linear_",
          &op::linear_,
          py::arg("out"),
          py::arg("input"),
          py::arg("weight"),
          py::arg("bias") = py::none(),
          R"doc(In-place, applies a linear transformation to the incoming data: y=xA^T+b.)doc");
}

} // namespace infinicore::ops

test/infinicore/ops/linear.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast

# ==============================================================================
# Operator-specific configuration
# ==============================================================================
_TEST_CASES_DATA = [
    # bs, n, in_features, out_features, bias
    (1, 5, 2048, 5632, True, None, None, None),
    (1, 1, 2048, 32000, False, None, None, None),
    (2, 5, 2048, 5632, True, None, None, None),
    (2, 5, 256, 2048, False, None, None, None),
    (None, 5, 256, 2048, False, None, None, None),
    (None, 1, 2048, 5632, True, None, None, None),
]

# Tolerance configuration
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 0, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


def parse_test_cases():
    """
    Parse test case data and return a list of TestCase objects for the linear operation.
    Each test case contains all necessary information for execution and validation.
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        bs = data[0]
        n, in_features, out_features = data[1], data[2], data[3]
        bias = data[4]
        input_strides = data[5] if len(data) > 5 else None
        weight_strides = data[6] if len(data) > 6 else None
        out_strides = data[7] if len(data) > 7 else None

        # Determine shapes based on the batch dimension
        if bs is None:
            input_shape = (n, in_features)
            weight_shape = (out_features, in_features)
            out_shape = (n, out_features)
        else:
            input_shape = (bs, n, in_features)
            weight_shape = (out_features, in_features)
            out_shape = (bs, n, out_features)

        if bias is True:
            bias_shape = (out_features,)
        else:
            bias_shape = None

        # Check if tensors support in-place operations
        c_supports_inplace = not is_broadcast(out_shape)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Create typed tensor specs
            input_spec = TensorSpec.from_tensor(input_shape, input_strides, dtype)
            weight_spec = TensorSpec.from_tensor(weight_shape, weight_strides, dtype)
            out_spec = TensorSpec.from_tensor(out_shape, out_strides, dtype)

            if bias_shape is not None:
                bias_spec = TensorSpec.from_tensor(bias_shape, None, dtype)
            else:
                bias_spec = None

            # Test case 1: out-of-place (return value)
            test_cases.append(
                TestCase(
                    inputs=[input_spec, weight_spec, bias_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="Linear - OUT_OF_PLACE",
                )
            )

            # Test case 2: in-place with an explicit output tensor (linear(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec, weight_spec, bias_spec],
                        kwargs=None,
                        output_spec=out_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description="Linear - INPLACE(out)",
                    )
                )

    return test_cases


class OpTest(BaseOperatorTest):
    """Linear operator test with simplified implementation"""

    def __init__(self):
        super().__init__("Linear")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, input, weight, bias, out=None, **kwargs):
        """PyTorch linear implementation"""
        result = torch.nn.functional.linear(input, weight, bias)
        if out is not None:
            out.copy_(result)
            return out
        return result

    def infinicore_operator(self, input, weight, bias, out=None, **kwargs):
        """InfiniCore linear implementation"""
        return infinicore.nn.functional.linear(input, weight, bias, out=out)


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
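
Since the module prepends the framework directory to sys.path and guards a main() entry point, the test can presumably be run directly, e.g. with "python test/infinicore/ops/linear.py" from the repository root, assuming infinicore and torch are importable.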
