Skip to content
This repository was archived by the owner on Aug 27, 2024. It is now read-only.

Commit ae3da06

Browse files
authored
patch duplicate function (#73)
Co-authored-by: Sam <sam.vanhoutte@marchitec.be>
1 parent eef7565 commit ae3da06

File tree

1 file changed

+26
-25
lines changed

1 file changed

+26
-25
lines changed

arcus/azureml/experimenting/aml_trainer.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -286,25 +286,6 @@ def evaluate_image_classifier(self, fitted_model, X_test: np.array, y_test: np.a
286286
if return_predictions:
287287
return y_pred
288288

289-
def save_image_outputs(self, X_test: np.array, y_test: np.array, y_pred: np.array, samples_to_save: int = 1) -> np.array:
290-
'''
291-
Will save image outputs to the run
292-
Args:
293-
X_test (np.array): The input images for the model
294-
y_test (np.array): The actual expected output images of the model
295-
y_pred (np.array): The predicted or calculated output images of the model
296-
samples_to_save (int): If greather than 0, this amount of input, output and generated image combinations will be tracked to the Run
297-
'''
298-
299-
if samples_to_save > 0:
300-
import random
301-
total_images = min(len(y_pred), samples_to_save)
302-
303-
for i in random.sample(range(len(y_pred)), total_images):
304-
newimg = self.concat_images([X_test[i], y_test[i], y_pred[i]])
305-
imgplot = explorer.show_image(newimg, silent_mode=True)
306-
self.__current_run.log_image(f'Image combo sample {i}', plot=imgplot)
307-
imgplot.close()
308289

309290

310291
def __stack_images(self, img1: np.array, img2: np.array):
@@ -365,7 +346,6 @@ def evaluate_image_classifier(self, fitted_model, X_test: np.array, y_test: np.a
365346
return y_pred
366347

367348
def save_image_outputs(self, X_test: np.array, y_test: np.array, y_pred: np.array, samples_to_save: int = 1) -> np.array:
368-
369349
'''
370350
Will save image outputs to the run
371351
Args:
@@ -376,15 +356,36 @@ def save_image_outputs(self, X_test: np.array, y_test: np.array, y_pred: np.arra
376356
'''
377357

378358
if samples_to_save > 0:
379-
# Take incorrect classified images and save
380359
import random
381360
total_images = min(len(y_pred), samples_to_save)
382361

383362
for i in random.sample(range(len(y_pred)), total_images):
384-
groupplot = explorer.visualize({'Charts': [X_test[i]], 'Actuals': [y_test[i]], 'Calculated': [y_pred[i]]}, 1, grid_size=(6, 6), silent_mode=True)
385-
image = X_test[i].reshape(X_test.shape[1], X_test.shape[2])
386-
imgplot = explorer.show_image(image, silent_mode=True)
387-
self.__current_run.log_image(f'Sample {i:02d} / {total_images:02d}', plot=groupplot)
363+
newimg = self.concat_images([X_test[i], y_test[i], y_pred[i]])
364+
imgplot = explorer.show_image(newimg, silent_mode=True)
365+
self.__current_run.log_image(f'Image combo sample {i}', plot=imgplot)
366+
imgplot.close()
367+
368+
# def save_image_outputs(self, X_test: np.array, y_test: np.array, y_pred: np.array, samples_to_save: int = 1) -> np.array:
369+
370+
# '''
371+
# Will save image outputs to the run
372+
# Args:
373+
# X_test (np.array): The input images for the model
374+
# y_test (np.array): The actual expected output images of the model
375+
# y_pred (np.array): The predicted or calculated output images of the model
376+
# samples_to_save (int): If greather than 0, this amount of input, output and generated image combinations will be tracked to the Run
377+
# '''
378+
379+
# if samples_to_save > 0:
380+
# # Take incorrect classified images and save
381+
# import random
382+
# total_images = min(len(y_pred), samples_to_save)
383+
384+
# for i in random.sample(range(len(y_pred)), total_images):
385+
# groupplot = explorer.visualize({'Charts': [X_test[i]], 'Actuals': [y_test[i]], 'Calculated': [y_pred[i]]}, 1, grid_size=(6, 6), silent_mode=True)
386+
# image = X_test[i].reshape(X_test.shape[1], X_test.shape[2])
387+
# imgplot = explorer.show_image(image, silent_mode=True)
388+
# self.__current_run.log_image(f'Sample {i:02d} / {total_images:02d}', plot=groupplot)
388389

389390
def setup_training(self, training_name: str, overwrite: bool = False):
390391
'''

0 commit comments

Comments
 (0)