@@ -149,7 +149,7 @@ def update(engine, i):
149149
150150 engine = Engine (update )
151151
152- device = "cpu" if idist .device ().type == "xla" else idist .device ()
152+ device = torch . device ( "cpu" ) if idist .device ().type == "xla" else idist .device ()
153153 metric = RocCurve (device = device )
154154 metric .attach (engine , "roc_curve" )
155155
@@ -159,10 +159,14 @@ def update(engine, i):
159159
160160 fpr , tpr , thresholds = engine .state .metrics ["roc_curve" ]
161161
162+ assert isinstance (fpr , torch .Tensor ) and fpr .device == device
163+ assert isinstance (tpr , torch .Tensor ) and tpr .device == device
164+ assert isinstance (thresholds , torch .Tensor ) and thresholds .device == device
165+
162166 y = idist .all_gather (y )
163167 y_pred = idist .all_gather (y_pred )
164- sk_fpr , sk_tpr , sk_thresholds = roc_curve (y , y_pred )
168+ sk_fpr , sk_tpr , sk_thresholds = roc_curve (y . cpu (). numpy () , y_pred . cpu (). numpy () )
165169
166- assert np .array_equal (fpr , sk_fpr )
167- assert np .array_equal (tpr , sk_tpr )
168- np .testing .assert_array_almost_equal (thresholds , sk_thresholds )
170+ np .testing . assert_array_almost_equal (fpr . cpu (). numpy () , sk_fpr )
171+ np .testing . assert_array_almost_equal (tpr . cpu (). numpy () , sk_tpr )
172+ np .testing .assert_array_almost_equal (thresholds . cpu (). numpy () , sk_thresholds )
0 commit comments