Skip to content

Commit 9b0b89c

Browse files
authored
Merge pull request #592 from InfiniTensor/issue/591
issue/591 添加infinicore.narrow
2 parents 5028ea4 + 16854ae commit 9b0b89c

File tree

17 files changed

+126
-18
lines changed

17 files changed

+126
-18
lines changed

include/infinicore/context/context.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Device getDevice();
1616
size_t getDeviceCount(Device::Type type);
1717

1818
infinirtStream_t getStream();
19-
infiniopHandle_t getInfiniopHandle();
19+
infiniopHandle_t getInfiniopHandle(Device device);
2020

2121
void syncStream();
2222
void syncDevice();

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from infinicore.ops.attention import attention
3232
from infinicore.ops.matmul import matmul
3333
from infinicore.ops.mul import mul
34+
from infinicore.ops.narrow import narrow
3435
from infinicore.ops.rearrange import rearrange
3536
from infinicore.tensor import (
3637
Tensor,
@@ -79,6 +80,7 @@
7980
"attention",
8081
"matmul",
8182
"mul",
83+
"narrow",
8284
"rearrange",
8385
"empty",
8486
"empty_like",

python/infinicore/ops/narrow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from infinicore.tensor import Tensor
2+
3+
4+
def narrow(input: Tensor, dim: int, start: int, length: int) -> Tensor:
5+
return Tensor(input._underlying.narrow(dim, start, length))

python/infinicore/tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def numel(self):
5656
def is_contiguous(self):
5757
return self._underlying.is_contiguous()
5858

59-
def is_is_pinned(self):
60-
return self._underlying.is_is_pinned()
59+
def is_pinned(self):
60+
return self._underlying.is_pinned()
6161

6262
def copy_(self, src):
6363
self._underlying.copy_(src._underlying)
@@ -67,12 +67,12 @@ def to(self, *args, **kwargs):
6767
self._underlying.to(*tuple(arg._underlying for arg in args), **kwargs)
6868
)
6969

70-
def as_strided(self, size, stride):
71-
return Tensor(self._underlying.as_strided(size, stride))
72-
7370
def contiguous(self):
7471
return Tensor(self._underlying.contiguous())
7572

73+
def as_strided(self, size, stride):
74+
return Tensor(self._underlying.as_strided(size, stride))
75+
7676
def permute(self, dims):
7777
return Tensor(self._underlying.permute(dims))
7878

src/infinicore/context/context_impl.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ infinirtStream_t getStream() {
9999
return ContextImpl::singleton().getCurrentRuntime()->stream();
100100
}
101101

102-
infiniopHandle_t getInfiniopHandle() {
102+
infiniopHandle_t getInfiniopHandle(Device device) {
103+
if (device.getType() == Device::Type::CPU) {
104+
return ContextImpl::singleton().getCpuRuntime()->infiniopHandle();
105+
}
106+
if (device != getDevice()) {
107+
throw std::runtime_error("Requested device doesn't match current runtime.");
108+
}
103109
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
104110
}
105111

src/infinicore/ops/add/add_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(c->device()), &desc,
3232
c->desc(), a->desc(), b->desc()));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/attention/attention_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(out->device()), &desc,
3232
out->desc(), q->desc(), k->desc(), v->desc(),
3333
k_cache->desc(), v_cache->desc(), pos));
3434
cache.put(seed, desc);

src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(output->device()), &desc,
3232
output->desc(), input->desc()));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/gemm/gemm_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(c->device()), &desc,
3232
c->desc(), a->desc(), b->desc()));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/mul/mul_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(c->device()), &desc,
3232
c->desc(), a->desc(), b->desc()));
3333
cache.put(seed, desc);
3434
} else {

0 commit comments

Comments
 (0)