@@ -456,20 +456,23 @@ end
456456
457457import ClimaComms
458458
459- ClimaComms. array_type (x:: FieldVector ) =
460- promote_type (unrolled_map (ClimaComms. array_type, _values (x))... )
461-
462- ClimaComms. device (x:: FieldVector ) = ClimaComms. device (ClimaComms. context (x))
463- function ClimaComms. context (x:: FieldVector )
464- isempty (_values (x)) && error (" Empty FieldVector has no device or context" )
465- # We don't have promotion for devices or contexts, so we use the first value
466- # that isn't a PointField (a PointField's data can be stored on a different
467- # device from other Fields to avoid scalar indexing on GPUs). If there is no
468- # such value, fall back to using the first PointField.
469- index = unrolled_findfirst (Base. Fix1 (! isa, PointField), _values (x))
470- return ClimaComms. context (_values (x)[isnothing (index) ? 1 : index])
459+ # To infer the ClimaComms device and its properties, use the first Field in a
460+ # FieldVector that isn't a PointField, since a PointField's data can be stored
461+ # on a different device from other Fields to avoid scalar indexing on GPUs. If
462+ # the FieldVector only contains PointFields, fall back to using the first one.
463+ function representative_field (x)
464+ all_fields = _values (x)
465+ isempty (all_fields) && error (" Empty FieldVector has no ClimaComms device" )
466+ field_index = unrolled_findfirst (Base. Fix2 (! isa, PointField), all_fields)
467+ return all_fields[isnothing (field_index) ? 1 : field_index]
471468end
472469
470+ ClimaComms. array_type (x:: FieldVector ) =
471+ ClimaComms. array_type (representative_field (x))
472+ ClimaComms. device (x:: FieldVector ) = ClimaComms. device (representative_field (x))
473+ ClimaComms. context (x:: FieldVector ) = ClimaComms. context (representative_field (x))
474+
475+
473476function __rprint_diff (
474477 io:: IO ,
475478 x:: T ,
0 commit comments