
Commit e6b047c

Author: pengcheng888 (committed)

issue/581 - Adjust where the Python linear implementation lives; update the tests to use the new call.
1 parent 3567f92 commit e6b047c

File tree

8 files changed: +300 −8 lines

include/infinicore/ops/linear.hpp

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
#pragma once

#include "common/op.hpp"

namespace infinicore::op {

Tensor linear(Tensor input, Tensor weight, std::optional<Tensor> bias);

void linear_(Tensor out, Tensor input, Tensor weight, std::optional<Tensor> bias);

} // namespace infinicore::op
Lines changed: 2 additions & 7 deletions
@@ -1,13 +1,8 @@
 from .causal_softmax import causal_softmax
+from .linear import linear
 from .random_sample import random_sample
 from .rms_norm import rms_norm
 from .silu import silu
 from .swiglu import swiglu

-__all__ = [
-    "causal_softmax",
-    "random_sample",
-    "rms_norm",
-    "silu",
-    "swiglu",
-]
+__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu", "linear"]
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["linear"]


def linear(input: Tensor, weight: Tensor, bias=None, *, out=None) -> Tensor:
    r"""Applies a linear transformation to the incoming data: y = xA^T + b."""

    if out is None:
        return Tensor(
            _infinicore.linear(
                input._underlying,
                weight._underlying,
                None if bias is None else bias._underlying,
            )
        )

    _infinicore.linear_(
        out._underlying,
        input._underlying,
        weight._underlying,
        None if bias is None else bias._underlying,
    )
    return out
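For reference, this wrapper targets the same semantics as torch.nn.functional.linear, which is also what the new test below compares against: y = xA^T + b, with the bias broadcast over all leading dimensions. A minimal PyTorch sketch of those semantics (shapes borrowed from the test cases; names are illustrative, not part of the commit):

import torch

bs, n, in_features, out_features = 2, 5, 256, 2048
x = torch.randn(bs, n, in_features)
w = torch.randn(out_features, in_features)   # weight stored as (out_features, in_features)
b = torch.randn(out_features)

y_ref = torch.nn.functional.linear(x, w, b)  # shape: (bs, n, out_features)
y_manual = x @ w.T + b                       # y = xA^T + b, bias broadcast over leading dims
assert torch.allclose(y_ref, y_manual, rtol=1e-4, atol=1e-4)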
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
#include "infinicore/ops/linear.hpp"
#include "infinicore/ops/add.hpp"
#include "infinicore/ops/matmul.hpp"

namespace infinicore::op {

Tensor linear(Tensor input,
              Tensor weight,
              std::optional<Tensor> bias) {

    Size ndim = input->ndim();
    Size out_features = weight->shape()[0];

    // Allocate the output tensor
    auto output_shape = input->shape();
    output_shape[ndim - 1] = out_features;
    auto out = Tensor::empty(output_shape, input->dtype(), input->device());

    // Compute in place into the freshly allocated output
    linear_(out, input, weight, bias);
    return out;
}

void linear_(Tensor out,
             Tensor input,
             Tensor weight,
             std::optional<Tensor> bias) {

    auto weight_shape = weight->shape();
    Size out_features = weight_shape[0];
    Size in_features = weight_shape[1];

    Size ndim = input->ndim();
    assert(out->ndim() == ndim);

    // Flatten all leading dimensions into N rows
    Size N = 1;
    auto input_shape = input->shape();
    for (size_t i = 0; i < ndim - 1; ++i) {
        N *= input_shape[i];
    }

    // Linear transformation: out = input * weight^T
    Tensor out_view = out->view({N, out_features});
    matmul_(out_view,
            input->view({N, in_features}),
            weight->permute({1, 0}));

    // Add bias, broadcast across the N rows via a zero-stride view
    if (bias.has_value()) {
        add_(out_view,
             out_view,
             bias.value()->as_strided({N, out_features}, {0, 1}));
    }
}

} // namespace infinicore::op
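The bias path above relies on a zero-stride view: bias.value()->as_strided({N, out_features}, {0, 1}) exposes the 1-D bias as an N × out_features matrix whose rows all alias the same storage, so a single element-wise add_ applies the bias to every flattened row. A small PyTorch sketch of the same zero-stride trick (illustrative names only, not part of the commit):

import torch

N, out_features = 10, 2048
out_view = torch.randn(N, out_features)
bias = torch.randn(out_features)

# Stride 0 along dim 0: every row of the view aliases the same bias storage.
bias_rows = torch.as_strided(bias, size=(N, out_features), stride=(0, 1))

# Identical to ordinary broadcasting of the 1-D bias over the rows.
assert torch.equal(out_view + bias_rows, out_view + bias)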

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include "ops/add.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
 #include "ops/random_sample.hpp"
@@ -22,6 +23,7 @@ inline void bind(py::module &m) {
     bind_attention(m);
     bind_causal_softmax(m);
     bind_random_sample(m);
+    bind_linear(m);
     bind_matmul(m);
     bind_mul(m);
     bind_rearrange(m);
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
#pragma once

#include "infinicore/ops/linear.hpp"

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace infinicore::ops {

Tensor py_linear(Tensor input,
                 Tensor weight,
                 pybind11::object bias) {
    std::optional<Tensor> bias_tensor = std::nullopt;
    if (!bias.is_none()) {
        bias_tensor = bias.cast<Tensor>();
    }
    return op::linear(input, weight, bias_tensor);
}

void py_linear_(Tensor out,
                Tensor input,
                Tensor weight,
                pybind11::object bias) {

    std::optional<Tensor> bias_tensor = std::nullopt;
    if (!bias.is_none()) {
        bias_tensor = bias.cast<Tensor>();
    }

    op::linear_(out, input, weight, bias_tensor);
}

inline void bind_linear(py::module &m) {

    m.def("linear",
          &ops::py_linear,
          py::arg("input"),
          py::arg("weight"),
          py::arg("bias") = py::none(),
          R"doc(Applies a linear transformation to the incoming data: y = xA^T + b.)doc");

    m.def("linear_",
          &ops::py_linear_,
          py::arg("out"),
          py::arg("input"),
          py::arg("weight"),
          py::arg("bias") = py::none(),
          R"doc(Applies a linear transformation to the incoming data in place, writing y = xA^T + b into `out`.)doc");
}

} // namespace infinicore::ops

test/infinicore/framework/tensor.py

Lines changed: 1 addition & 1 deletion
@@ -351,6 +351,6 @@ def __str__(self):
         else:
             strides_str = f", strides={self.strides}" if self.strides else ""
         dtype_str = (
-            f", {str(self.dtype).replace("infinicore.", "")}" if self.dtype else ""
+            f", {str(self.dtype).replace('infinicore.', '')}" if self.dtype else ""
         )
         return f"{name_str}tensor{self.shape}{strides_str}{dtype_str}"

test/infinicore/ops/linear.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast

import infinicore

# ==============================================================================
# Operator-specific configuration
# ==============================================================================
_TEST_CASES_DATA = [
    # (bs, n, in_features, out_features, bias, input_strides, weight_strides, out_strides)
    (1, 5, 2048, 5632, True, None, None, None),
    (1, 1, 2048, 32000, False, None, None, None),
    (2, 5, 2048, 5632, True, None, None, None),
    (2, 5, 256, 2048, False, None, None, None),
    (None, 5, 256, 2048, False, None, None, None),
    (None, 1, 2048, 5632, True, None, None, None),
]

# Tolerance configuration
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 1e-2},
    infinicore.float32: {"atol": 0, "rtol": 1e-3},
    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


def parse_test_cases():
    """
    Parse the test case data and return a list of TestCase objects for the linear operation.
    Each test case contains all the information needed for execution and validation.
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        bs = data[0]
        n, in_features, out_features = data[1], data[2], data[3]
        bias = data[4]
        input_strides = data[5] if len(data) > 5 else None
        weight_strides = data[6] if len(data) > 6 else None
        out_strides = data[7] if len(data) > 7 else None

        # Determine shapes based on the batch dimension
        if bs is None:
            input_shape = (n, in_features)
            weight_shape = (out_features, in_features)
            out_shape = (n, out_features)
        else:
            input_shape = (bs, n, in_features)
            weight_shape = (out_features, in_features)
            out_shape = (bs, n, out_features)

        if bias is True:
            bias_shape = (out_features,)
        else:
            bias_shape = None

        # Check if tensors support in-place operations
        c_supports_inplace = not is_broadcast(out_shape)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})

            # Create typed tensor specs
            input_spec = TensorSpec.from_tensor(input_shape, input_strides, dtype)
            weight_spec = TensorSpec.from_tensor(weight_shape, weight_strides, dtype)
            out_spec = TensorSpec.from_tensor(out_shape, out_strides, dtype)

            if bias_shape is not None:
                bias_spec = TensorSpec.from_tensor(bias_shape, None, dtype)
            else:
                bias_spec = None

            # Test case 1: out-of-place (result returned)
            test_cases.append(
                TestCase(
                    inputs=[input_spec, weight_spec, bias_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="Linear - OUT_OF_PLACE",
                )
            )

            # Test case 2: in-place with an explicit output tensor (linear(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec, weight_spec, bias_spec],
                        kwargs=None,
                        output_spec=out_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description="Linear - INPLACE(out)",
                    )
                )

    return test_cases


class OpTest(BaseOperatorTest):
    """Linear operator test with simplified implementation"""

    def __init__(self):
        super().__init__("Linear")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, input, weight, bias, out=None, **kwargs):
        """PyTorch linear reference implementation"""
        result = torch.nn.functional.linear(input, weight, bias)
        if out is not None:
            out.copy_(result)
            return out
        return result

    def infinicore_operator(self, input, weight, bias, out=None, **kwargs):
        """InfiniCore linear implementation"""
        return infinicore.nn.functional.linear(input, weight, bias, out=out)

    # def torch_operator(self, *args, **kwargs):
    #     """PyTorch linear implementation"""

    #     return torch.nn.functional.linear(*args, **kwargs)

    # def infinicore_operator(self, *args, **kwargs):
    #     """InfiniCore linear implementation"""
    #     return infinicore.nn.functional.linear(input, *args, **kwargs)


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
