Skip to content

Commit 16854ae

Browse files
committed
issue/591 - fix operator context mismatch
1 parent a311e9c commit 16854ae

File tree

12 files changed

+18
-12
lines changed

12 files changed

+18
-12
lines changed

include/infinicore/context/context.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Device getDevice();
1616
size_t getDeviceCount(Device::Type type);
1717

1818
infinirtStream_t getStream();
19-
infiniopHandle_t getInfiniopHandle();
19+
infiniopHandle_t getInfiniopHandle(Device device);
2020

2121
void syncStream();
2222
void syncDevice();

src/infinicore/context/context_impl.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ infinirtStream_t getStream() {
9999
return ContextImpl::singleton().getCurrentRuntime()->stream();
100100
}
101101

102-
infiniopHandle_t getInfiniopHandle() {
102+
infiniopHandle_t getInfiniopHandle(Device device) {
103+
if (device.getType() == Device::Type::CPU) {
104+
return ContextImpl::singleton().getCpuRuntime()->infiniopHandle();
105+
}
106+
if (device != getDevice()) {
107+
throw std::runtime_error("Requested device doesn't match current runtime.");
108+
}
103109
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
104110
}
105111

src/infinicore/ops/add/add_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(c->device()), &desc,
3232
c->desc(), a->desc(), b->desc()));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/attention/attention_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(out->device()), &desc,
3232
out->desc(), q->desc(), k->desc(), v->desc(),
3333
k_cache->desc(), v_cache->desc(), pos));
3434
cache.put(seed, desc);

src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(output->device()), &desc,
3232
output->desc(), input->desc()));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/gemm/gemm_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(c->device()), &desc,
3232
c->desc(), a->desc(), b->desc()));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/mul/mul_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(c->device()), &desc,
3232
c->desc(), a->desc(), b->desc()));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/rearrange/rearrange_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) {
2727
infiniopRearrangeDescriptor_t desc = nullptr;
2828

2929
if (!desc_opt) {
30-
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(), &desc, y->desc(), x->desc()));
30+
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc()));
3131
cache.put(seed, desc);
3232
} else {
3333
desc = *desc_opt;

src/infinicore/ops/rms_norm/rms_norm_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
2828

2929
if (!desc_opt) {
3030
INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
31-
context::getInfiniopHandle(), &desc,
31+
context::getInfiniopHandle(y->device()), &desc,
3232
y->desc(), x->desc(), weight->desc(), epsilon));
3333
cache.put(seed, desc);
3434
} else {

src/infinicore/ops/rope/rope_infiniop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s
4242

4343
if (!desc_opt) {
4444
INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
45-
context::getInfiniopHandle(), &desc,
45+
context::getInfiniopHandle(x_out->device()), &desc,
4646
x_out->desc(), x->desc(),
4747
pos->desc(), sin_cache->desc(), cos_cache->desc(),
4848
infiniop_algo));

0 commit comments

Comments
 (0)