@@ -71,10 +71,6 @@ def classifier_two_sample_test(
         full training history.
     """
 
-    # Convert tensors to numpy, if passed
-    estimates = keras.ops.convert_to_numpy(estimates)
-    targets = keras.ops.convert_to_numpy(targets)
-
     # Error, if targets dim does not match estimates dim
     num_dims = estimates.shape[1]
     if not num_dims == targets.shape[1]:
@@ -111,15 +107,30 @@ def classifier_two_sample_test(
         monitor=f"val_{metric}", patience=patience, restore_best_weights=True
     )
 
-    history = classifier.fit(
-        x=data,
-        y=labels,
-        epochs=max_epochs,
-        batch_size=batch_size,
-        verbose=0,
-        callbacks=[early_stopping],
-        validation_split=validation_split,
-    )
+    # For now, we need to enable grads, since we turn them off by default
+    if keras.backend.backend() == "torch":
+        import torch
+
+        with torch.enable_grad():
+            history = classifier.fit(
+                x=data,
+                y=labels,
+                epochs=max_epochs,
+                batch_size=batch_size,
+                verbose=0,
+                callbacks=[early_stopping],
+                validation_split=validation_split,
+            )
+    else:
+        history = classifier.fit(
+            x=data,
+            y=labels,
+            epochs=max_epochs,
+            batch_size=batch_size,
+            verbose=0,
+            callbacks=[early_stopping],
+            validation_split=validation_split,
+        )
 
     if return_metric_only:
         return history.history[f"val_{metric}"][-1]
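
One way the two branches added above could be collapsed into a single fit call is to select the gradient context up front. The sketch below is not part of the commit; it assumes the classifier, data, labels, and hyperparameters defined earlier in the function and uses contextlib.nullcontext() as a no-op stand-in on non-torch backends.

# Sketch only: assumes classifier, data, labels, max_epochs, batch_size,
# early_stopping, and validation_split from the surrounding function.
import contextlib

import keras

if keras.backend.backend() == "torch":
    import torch

    # Gradients are disabled globally for the torch backend, so re-enable them here
    grad_context = torch.enable_grad()
else:
    # No-op context manager for backends that do not need special handling
    grad_context = contextlib.nullcontext()

with grad_context:
    history = classifier.fit(
        x=data,
        y=labels,
        epochs=max_epochs,
        batch_size=batch_size,
        verbose=0,
        callbacks=[early_stopping],
        validation_split=validation_split,
    )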