Skip to content

Commit fc2f2a0

Browse files
author
pengcheng888
committed
issue/581 - 添加linear的实现,接口和测试
1 parent 2d0a83c commit fc2f2a0

File tree

9 files changed

+296
-70
lines changed

9 files changed

+296
-70
lines changed

include/infinicore/ops/linear.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include "common/op.hpp"
4+
5+
namespace infinicore::op {
6+
7+
Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias);
8+
9+
void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias);
10+
11+
} // namespace infinicore::op

python/infinicore/nn/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from infinicore.nn import (
2-
functional as functional,
3-
)
1+
from infinicore.nn import functional
2+
3+
__all__ = ["functional"]

python/infinicore/nn/functional.py

Lines changed: 0 additions & 67 deletions
This file was deleted.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .linear import linear
2+
3+
__all__ = ["linear"]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
__all__ = ["linear"]
5+
6+
7+
def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
8+
r"""Applies a linear transformation to the incoming data: y=xA^T+b."""
9+
10+
if out is None:
11+
return Tensor(
12+
_infinicore.linear(
13+
input._underlying,
14+
weight._underlying,
15+
None if bias is None else bias._underlying,
16+
)
17+
)
18+
19+
_infinicore.linear_(
20+
out._underlying,
21+
input._underlying,
22+
weight._underlying,
23+
None if bias is None else bias._underlying,
24+
)
25+
return out
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include "infinicore/ops/linear.hpp"
2+
#include "infinicore/ops/add.hpp"
3+
#include "infinicore/ops/matmul.hpp"
4+
5+
namespace infinicore::op {
6+
7+
Tensor linear(Tensor input,
8+
Tensor weight,
9+
std::optional<Tensor> bias) {
10+
11+
Size ndim = input->ndim();
12+
Size out_features = weight->shape()[0];
13+
14+
// Assign memory to out variables
15+
auto output_shape = input->shape();
16+
output_shape[ndim - 1] = out_features;
17+
auto out = Tensor::empty(output_shape, input->dtype(), input->device());
18+
19+
// Inplace Calculate
20+
linear_(out, input, weight, bias);
21+
return out;
22+
}
23+
24+
void linear_(Tensor out,
25+
Tensor input,
26+
Tensor weight,
27+
std::optional<Tensor> bias) {
28+
29+
auto weight_shape = weight->shape();
30+
Size out_features = weight_shape[0];
31+
Size in_features = weight_shape[1];
32+
33+
Size ndim = input->ndim();
34+
assert(out->ndim() == ndim);
35+
36+
// Calculate the number of features
37+
Size N = 1;
38+
auto input_shape = input->shape();
39+
for (size_t i = 0; i < ndim - 1; ++i) {
40+
N *= input_shape[i];
41+
}
42+
43+
// linear transformation
44+
Tensor out_view = out->view({N, out_features});
45+
matmul_(out_view,
46+
input->view({N, in_features}),
47+
weight->permute({1, 0}));
48+
49+
// Add bias
50+
if (bias.has_value()) {
51+
add_(out_view,
52+
out_view,
53+
bias.value()->as_strided({N, out_features}, {0, 1}));
54+
}
55+
}
56+
57+
} // namespace infinicore::op

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "ops/add.hpp"
66
#include "ops/attention.hpp"
77
#include "ops/causal_softmax.hpp"
8+
#include "ops/linear.hpp"
89
#include "ops/matmul.hpp"
910
#include "ops/rearrange.hpp"
1011
#include "ops/rms_norm.hpp"
@@ -19,6 +20,7 @@ inline void bind(py::module &m) {
1920
bind_add(m);
2021
bind_attention(m);
2122
bind_causal_softmax(m);
23+
bind_linear(m);
2224
bind_matmul(m);
2325
bind_rearrange(m);
2426
bind_rms_norm(m);
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#pragma once
2+
3+
#include "infinicore/ops/linear.hpp"
4+
5+
#include <pybind11/pybind11.h>
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops::linear {
10+
11+
Tensor py_linear(Tensor input,
12+
Tensor weight,
13+
pybind11::object bias) {
14+
std::optional<Tensor> bias_tensor = std::nullopt;
15+
if (!bias.is_none()) {
16+
bias_tensor = bias.cast<Tensor>();
17+
}
18+
return op::linear(input, weight, bias_tensor);
19+
}
20+
21+
void py_linear_(Tensor out,
22+
Tensor input,
23+
Tensor weight,
24+
pybind11::object bias) {
25+
26+
std::optional<Tensor> bias_tensor = std::nullopt;
27+
if (!bias.is_none()) {
28+
bias_tensor = bias.cast<Tensor>();
29+
}
30+
31+
op::linear_(out, input, weight, bias_tensor);
32+
}
33+
34+
} // namespace infinicore::ops::linear
35+
36+
namespace infinicore::ops {
37+
38+
inline void bind_linear(py::module &m) {
39+
m.def("linear",
40+
&linear::py_linear,
41+
py::arg("input"),
42+
py::arg("weight"),
43+
py::arg("bias") = py::none(),
44+
R"doc(Applies a linear transformation to the incoming data: y=xA^T+b.)doc");
45+
46+
m.def("linear_",
47+
&linear::py_linear_,
48+
py::arg("out"),
49+
py::arg("input"),
50+
py::arg("weight"),
51+
py::arg("bias") = py::none(),
52+
R"doc(In-place, applies a linear transformation to the incoming data: y=xA^T+b.)doc");
53+
}
54+
55+
} // namespace infinicore::ops

test/infinicore/ops/linear.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5+
6+
import torch
7+
import infinicore
8+
from framework.base import BaseOperatorTest, TensorSpec, TestCase
9+
from framework.runner import GenericTestRunner
10+
from framework.utils import is_broadcast
11+
12+
# ==============================================================================
13+
# Operator-specific configuration
14+
# ==============================================================================
15+
_TEST_CASES_DATA = [
16+
# bs, n, in_features, out_features, bias
17+
(1, 5, 2048, 5632, True, None, None, None),
18+
(1, 1, 2048, 32000, False, None, None, None),
19+
(2, 5, 2048, 5632, True, None, None, None),
20+
(2, 5, 256, 2048, False, None, None, None),
21+
(None, 5, 256, 2048, False, None, None, None),
22+
(None, 1, 2048, 5632, True, None, None, None),
23+
]
24+
25+
# Tolerance configuration
26+
_TOLERANCE_MAP = {
27+
infinicore.float16: {"atol": 0, "rtol": 1e-2},
28+
infinicore.float32: {"atol": 0, "rtol": 1e-3},
29+
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
30+
}
31+
32+
# Data types to test
33+
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
34+
35+
36+
def parse_test_cases():
37+
"""
38+
Parse test case data and return list of TestCase objects for linear operation.
39+
Each test case contains all necessary information for execution and validation.
40+
"""
41+
test_cases = []
42+
43+
for data in _TEST_CASES_DATA:
44+
bs = data[0]
45+
n, in_features, out_features = data[1], data[2], data[3]
46+
bias = data[4]
47+
input_strides = data[5] if len(data) > 5 else None
48+
weight_strides = data[6] if len(data) > 6 else None
49+
out_strides = data[7] if len(data) > 7 else None
50+
51+
# Determine shapes based on batch dimension
52+
if bs is None:
53+
input_shape = (n, in_features)
54+
weight_shape = (out_features, in_features)
55+
out_shape = (n, out_features)
56+
else:
57+
input_shape = (bs, n, in_features)
58+
weight_shape = (out_features, in_features)
59+
out_shape = (bs, n, out_features)
60+
61+
if bias is True:
62+
bias_shape = (out_features,)
63+
else:
64+
bias_shape = None
65+
66+
# Check if tensors support in-place operations
67+
c_supports_inplace = not is_broadcast(out_shape)
68+
69+
# Generate test cases for all data types
70+
for dtype in _TENSOR_DTYPES:
71+
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
72+
73+
# Create typed tensor specs
74+
input_spec = TensorSpec.from_tensor(input_shape, input_strides, dtype)
75+
weight_spec = TensorSpec.from_tensor(weight_shape, weight_strides, dtype)
76+
out_spec = TensorSpec.from_tensor(out_shape, out_strides, dtype)
77+
78+
if bias_shape is not None:
79+
bias_spec = TensorSpec.from_tensor(bias_shape, None, dtype)
80+
else:
81+
bias_spec = None
82+
83+
# Test Case 1: Out-of-place (return value)
84+
test_cases.append(
85+
TestCase(
86+
inputs=[input_spec, weight_spec, bias_spec],
87+
kwargs={},
88+
output_spec=None,
89+
comparison_target=None,
90+
tolerance=tolerance,
91+
description=f"Linear - OUT_OF_PLACE",
92+
)
93+
)
94+
95+
# Test Case 2: In-place with explicit output tensor (Linear(a, b, out=c))
96+
if c_supports_inplace:
97+
test_cases.append(
98+
TestCase(
99+
inputs=[input_spec, weight_spec, bias_spec],
100+
kwargs=None,
101+
output_spec=out_spec, # Specify the output tensor spec
102+
comparison_target="out",
103+
tolerance=tolerance,
104+
description=f"Linear - INPLACE(out)",
105+
)
106+
)
107+
108+
return test_cases
109+
110+
111+
class OpTest(BaseOperatorTest):
112+
"""Linear operator test with simplified implementation"""
113+
114+
def __init__(self):
115+
super().__init__("Linear")
116+
117+
def get_test_cases(self):
118+
return parse_test_cases()
119+
120+
def torch_operator(self, input, weight, bias, out=None, **kwargs):
121+
"""PyTorch linear implementation"""
122+
result = torch.nn.functional.linear(input, weight, bias)
123+
if out is not None:
124+
out.copy_(result)
125+
return out
126+
return result
127+
128+
def infinicore_operator(self, input, weight, bias, out=None, **kwargs):
129+
"""InfiniCore linear implementation"""
130+
return infinicore.nn.functional.linear(input, weight, bias, out=out)
131+
132+
133+
def main():
134+
"""Main entry point"""
135+
runner = GenericTestRunner(OpTest)
136+
runner.run_and_exit()
137+
138+
139+
if __name__ == "__main__":
140+
main()

0 commit comments

Comments
 (0)