@@ -399,10 +399,15 @@ function _Remapper(
399399 )
400400 num_dims = num_hdims
401401 else
402+ cpu_space = if ClimaComms. device (space) isa AbstractCPUDevice
403+ space
404+ else
405+ to_cpu (space)
406+ end
402407 vert_interpolation_weights =
403- ArrayType (vertical_interpolation_weights (space , target_zcoords))
408+ ArrayType (vertical_interpolation_weights (cpu_space , target_zcoords))
404409 vert_bounding_indices =
405- ArrayType (vertical_bounding_indices (space , target_zcoords))
410+ ArrayType (vertical_bounding_indices (cpu_space , target_zcoords))
406411
407412 # We have to add one extra dimension with respect to the bitmask/local_horiz_indices
408413 # because we are going to store the values for the columns
@@ -463,10 +468,16 @@ function _Remapper(
463468 FT = Spaces. undertype (space)
464469 ArrayType = ClimaComms. array_type (space)
465470
471+ cpu_space = if ClimaComms. device (space) isa AbstractCPUDevice
472+ space
473+ else
474+ to_cpu (space)
475+ end
476+
466477 vert_interpolation_weights =
467- ArrayType (vertical_interpolation_weights (space , target_zcoords))
478+ ArrayType (vertical_interpolation_weights (cpu_space , target_zcoords))
468479 vert_bounding_indices =
469- ArrayType (vertical_bounding_indices (space , target_zcoords))
480+ ArrayType (vertical_bounding_indices (cpu_space , target_zcoords))
470481
471482 local_interpolated_values =
472483 ArrayType (zeros (FT, (length (target_zcoords), buffer_length)))
0 commit comments