@@ -591,13 +591,13 @@ auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data,
591
591
592
592
template <typename T, typename ... S>
593
593
auto MakeTensorView (Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
594
- auto span = ctx->IsCUDA () ? data->DeviceSpan () : data->HostSpan ();
594
+ auto span = ctx->IsCPU () ? data->HostSpan () : data->DeviceSpan ();
595
595
return MakeTensorView (ctx->Device (), span, std::forward<S>(shape)...);
596
596
}
597
597
598
598
template <typename T, typename ... S>
599
599
auto MakeTensorView (Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
600
- auto span = ctx->IsCUDA () ? data->ConstDeviceSpan () : data->ConstHostSpan ();
600
+ auto span = ctx->IsCPU () ? data->ConstHostSpan () : data->ConstDeviceSpan ();
601
601
return MakeTensorView (ctx->Device (), span, std::forward<S>(shape)...);
602
602
}
603
603
@@ -647,13 +647,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
647
647
648
648
template <typename T>
649
649
auto MakeVec (HostDeviceVector<T> *data) {
650
- return MakeVec (data->Device ().IsCUDA () ? data->DevicePointer () : data->HostPointer (),
650
+ return MakeVec (data->Device ().IsCPU () ? data->HostPointer () : data->DevicePointer (),
651
651
data->Size (), data->Device ());
652
652
}
653
653
654
654
template <typename T>
655
655
auto MakeVec (HostDeviceVector<T> const *data) {
656
- return MakeVec (data->Device ().IsCUDA () ? data->ConstDevicePointer () : data->ConstHostPointer (),
656
+ return MakeVec (data->Device ().IsCPU () ? data->ConstHostPointer () : data->ConstDevicePointer (),
657
657
data->Size (), data->Device ());
658
658
}
659
659
@@ -759,7 +759,7 @@ class Tensor {
759
759
for (auto i = D; i < kDim ; ++i) {
760
760
shape_[i] = 1 ;
761
761
}
762
- if (device.IsCUDA ()) {
762
+ if (! device.IsCPU ()) {
763
763
data_.SetDevice (device);
764
764
data_.ConstDevicePointer (); // Pull to device;
765
765
}
@@ -788,11 +788,11 @@ class Tensor {
788
788
shape_[i] = 1 ;
789
789
}
790
790
auto size = detail::CalcSize (shape_);
791
- if (device.IsCUDA ()) {
791
+ if (! device.IsCPU ()) {
792
792
data_.SetDevice (device);
793
793
}
794
794
data_.Resize (size);
795
- if (device.IsCUDA ()) {
795
+ if (! device.IsCPU ()) {
796
796
data_.DevicePointer (); // Pull to device
797
797
}
798
798
}
0 commit comments