Skip to content

Commit cb4cc2d

Browse files
committed
[CM-1579]: Fixed ExperimentContext handling by online experiment. Fixed related test cases.
1 parent ac117e5 commit cb4cc2d

File tree

6 files changed

+138
-50
lines changed

6 files changed

+138
-50
lines changed

comet-java-client/src/main/java/ml/comet/experiment/context/ExperimentContext.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,35 @@ public void mergeFrom(@NonNull ExperimentContext other) {
8282
}
8383
}
8484

85+
/**
86+
* Allows to check if this is empty context, i.e., has no value set.
87+
*
88+
* @return {@code true} if this is empty context.
89+
*/
90+
public boolean isEmpty() {
91+
return Objects.isNull(this.step) && Objects.isNull(this.epoch) && StringUtils.isBlank(this.context);
92+
}
93+
94+
/**
95+
* Indicates whether provided context object equals to this one.
96+
*
97+
* @param obj the instance of {@code ExperimentContext} to check.
98+
* @return {@code true} if provided context object equals to this one.
99+
*/
100+
public boolean equals(Object obj) {
101+
if (Objects.isNull(obj) || !(obj instanceof ExperimentContext)) {
102+
return false;
103+
}
104+
105+
if (this == obj) {
106+
return true;
107+
}
108+
ExperimentContext ctx = (ExperimentContext) obj;
109+
return Objects.equals(this.context, ctx.context)
110+
&& Objects.equals(this.epoch, ctx.epoch)
111+
&& Objects.equals(this.step, ctx.step);
112+
}
113+
85114
/**
86115
* The factory to return empty {@link ExperimentContext} instance.
87116
*

comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperimentAsync.java

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,12 @@ abstract class BaseExperimentAsync extends BaseExperiment {
8888
this.baseContext = ExperimentContext.empty();
8989
}
9090

91-
void updateContext(ExperimentContext context) {
92-
this.baseContext.mergeFrom(context);
91+
ExperimentContext mergeWithBaseContextIfEmpty(ExperimentContext context) {
92+
if (context.isEmpty()) {
93+
return new ExperimentContext(this.baseContext);
94+
} else {
95+
return context;
96+
}
9397
}
9498

9599
/**
@@ -104,13 +108,13 @@ void updateContext(ExperimentContext context) {
104108
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
105109
void logMetric(@NonNull String metricName, @NonNull Object metricValue,
106110
@NonNull ExperimentContext context, @NonNull Optional<Action> onComplete) {
107-
this.updateContext(context);
111+
ExperimentContext ctx = mergeWithBaseContextIfEmpty(context);
108112

109113
if (getLogger().isDebugEnabled()) {
110-
getLogger().debug("logMetricAsync {} = {}, context: {}", metricName, metricValue, context);
114+
getLogger().debug("logMetricAsync {} = {}, context: {}", metricName, metricValue, ctx);
111115
}
112116

113-
MetricRest metricRequest = createLogMetricRequest(metricName, metricValue, this.baseContext);
117+
MetricRest metricRequest = createLogMetricRequest(metricName, metricValue, ctx);
114118
this.sendAsynchronously(getRestApiClient()::logMetric, metricRequest, onComplete);
115119
}
116120

@@ -126,13 +130,13 @@ void logMetric(@NonNull String metricName, @NonNull Object metricValue,
126130
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
127131
void logParameter(@NonNull String parameterName, @NonNull Object paramValue,
128132
@NonNull ExperimentContext context, @NonNull Optional<Action> onComplete) {
129-
this.updateContext(context);
133+
ExperimentContext ctx = mergeWithBaseContextIfEmpty(context);
130134

131135
if (getLogger().isDebugEnabled()) {
132-
getLogger().debug("logParameterAsync {} = {}, context: {}", parameterName, paramValue, context);
136+
getLogger().debug("logParameterAsync {} = {}, context: {}", parameterName, paramValue, ctx);
133137
}
134138

135-
ParameterRest paramRequest = createLogParamRequest(parameterName, paramValue, this.baseContext);
139+
ParameterRest paramRequest = createLogParamRequest(parameterName, paramValue, ctx);
136140
this.sendAsynchronously(getRestApiClient()::logParameter, paramRequest, onComplete);
137141
}
138142

@@ -316,10 +320,7 @@ void logAssetFolder(@NonNull File folder, boolean logFilePath, boolean recursive
316320
getLogger().warn(getString(LOG_ASSET_FOLDER_EMPTY, folder));
317321
return;
318322
}
319-
this.updateContext(context);
320-
// make deep copy of the current experiment context to avoid side effects
321-
// if base experiment context become updated while operation is still in progress
322-
ExperimentContext assetContext = new ExperimentContext(this.baseContext);
323+
ExperimentContext assetContext = mergeWithBaseContextIfEmpty(context);
323324

324325
AtomicInteger successfullyLoggedCount = new AtomicInteger();
325326
try {
@@ -381,10 +382,10 @@ void logAssetFolder(@NonNull File folder, boolean logFilePath, boolean recursive
381382
void logRemoteAsset(@NonNull URI uri, Optional<String> logicalPath, boolean overwrite,
382383
@NonNull Optional<Map<String, Object>> metadata, @NonNull ExperimentContext context,
383384
@NonNull Optional<Action> onComplete) {
384-
this.updateContext(context);
385+
ExperimentContext ctx = mergeWithBaseContextIfEmpty(context);
385386

386387
RemoteAssetImpl asset = AssetUtils.createRemoteAsset(uri, logicalPath, overwrite, metadata, empty());
387-
this.logAssetAsync(getRestApiClient()::logRemoteAsset, asset, onComplete);
388+
this.logAssetAsync(getRestApiClient()::logRemoteAsset, asset, ctx, onComplete);
388389

389390
if (Objects.equals(asset.getLogicalPath(), AssetUtils.REMOTE_FILE_NAME_DEFAULT)) {
390391
getLogger().warn(
@@ -509,6 +510,7 @@ void updateArtifactVersionState(@NonNull LoggedArtifact loggedArtifact,
509510
* @param assetType the type of the asset.
510511
* @param groupingName optional name of group this asset should belong.
511512
* @param metadata the optional metadata to associate.
513+
* @param context the experiment context to be associated with given assets.
512514
* @param onComplete The optional action to be invoked when this operation asynchronously completes.
513515
* Can be {@code null} if not interested in completion signal.
514516
*/
@@ -519,11 +521,12 @@ void logAssetDataAsync(byte[] data, @NonNull String fileName, boolean overwrite,
519521
@NonNull Optional<Map<String, Object>> metadata,
520522
@NonNull ExperimentContext context,
521523
@NonNull Optional<Action> onComplete) {
522-
this.updateContext(context);
523524

524525
AssetImpl asset = createAssetFromData(data, fileName, overwrite, metadata, assetType);
525526
groupingName.ifPresent(asset::setGroupingName);
526-
this.logAssetAsync(asset, onComplete);
527+
ExperimentContext ctx = mergeWithBaseContextIfEmpty(context);
528+
529+
this.logAssetAsync(asset, ctx, onComplete);
527530
}
528531

529532
/**
@@ -535,6 +538,7 @@ void logAssetDataAsync(byte[] data, @NonNull String fileName, boolean overwrite,
535538
* @param assetType the type of the asset.
536539
* @param groupingName optional name of group this asset should belong.
537540
* @param metadata the optional metadata to associate.
541+
* @param context the experiment context to be associated with given assets.
538542
* @param onComplete The optional action to be invoked when this operation asynchronously completes.
539543
* Can be {@code null} if not interested in completion signal.
540544
*/
@@ -545,11 +549,12 @@ void logAssetFileAsync(@NonNull File file, @NonNull String fileName, boolean ove
545549
@NonNull Optional<Map<String, Object>> metadata,
546550
@NonNull ExperimentContext context,
547551
@NonNull Optional<Action> onComplete) {
548-
this.updateContext(context);
549552

550553
AssetImpl asset = createAssetFromFile(file, Optional.of(fileName), overwrite, metadata, assetType);
551554
groupingName.ifPresent(asset::setGroupingName);
552-
this.logAssetAsync(asset, onComplete);
555+
ExperimentContext ctx = mergeWithBaseContextIfEmpty(context);
556+
557+
this.logAssetAsync(asset, ctx, onComplete);
553558
}
554559

555560
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@@ -565,27 +570,31 @@ void logAssetFileAsync(@NonNull File file, @NonNull String fileName, boolean ove
565570
* Asynchronously logs provided asset and signals upload completion if {@code onComplete} action provided.
566571
*
567572
* @param asset the {@link Asset} to be uploaded.
573+
* @param context the experiment context to be associated with given assets.
568574
* @param onComplete the optional {@link Action} to be called upon operation completed,
569575
* either successful or failure.
570576
*/
571577
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
572-
private void logAssetAsync(@NonNull final Asset asset, @NonNull Optional<Action> onComplete) {
573-
this.logAssetAsync(getRestApiClient()::logAsset, asset, onComplete);
578+
private void logAssetAsync(@NonNull final Asset asset, @NonNull ExperimentContext context,
579+
@NonNull Optional<Action> onComplete) {
580+
this.logAssetAsync(getRestApiClient()::logAsset, asset, context, onComplete);
574581
}
575582

576583
/**
577584
* Attempts to log provided {@link AssetImpl} or its subclass asynchronously using specified log function.
578585
*
586+
* @param <T> the {@link AssetImpl} or its subclass.
579587
* @param func the function to be invoked to send asset to the backend.
580588
* @param asset the {@link AssetImpl} or subclass to be sent.
589+
* @param context the experiment context to be associated with given assets.
581590
* @param onComplete The optional action to be invoked when this operation
582591
* asynchronously completes. Can be empty if not interested in completion signal.
583-
* @param <T> the {@link AssetImpl} or its subclass.
584592
*/
585593
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
586594
private <T extends Asset> void logAssetAsync(@NonNull final BiFunction<T, String, Single<RestApiResponse>> func,
587-
@NonNull final T asset, @NonNull Optional<Action> onComplete) {
588-
((AssetImpl) asset).setContext(this.baseContext);
595+
@NonNull final T asset, @NonNull ExperimentContext context,
596+
@NonNull Optional<Action> onComplete) {
597+
((AssetImpl) asset).setContext(context);
589598
Single<RestApiResponse> single = this.sendAssetAsync(func, asset);
590599

591600
if (onComplete.isPresent()) {

comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentImpl.java

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,12 @@ public void logMetric(String metricName, Object metricValue, long step) {
284284

285285
@Override
286286
public void logMetric(String metricName, Object metricValue) {
287-
this.logMetric(metricName, metricValue, this.baseContext);
287+
this.logMetric(metricName, metricValue, ExperimentContext.empty());
288288
}
289289

290290
@Override
291291
public void logParameter(@NonNull String parameterName, @NonNull Object paramValue) {
292-
this.logParameter(parameterName, paramValue, this.baseContext);
292+
this.logParameter(parameterName, paramValue, ExperimentContext.empty());
293293
}
294294

295295
@Override
@@ -368,7 +368,7 @@ public void logAssetFolder(@NonNull File folder, boolean logFilePath,
368368

369369
@Override
370370
public void logAssetFolder(@NonNull File folder, boolean logFilePath, boolean recursive) {
371-
this.logAssetFolder(folder, logFilePath, recursive, this.baseContext);
371+
this.logAssetFolder(folder, logFilePath, recursive, ExperimentContext.empty());
372372
}
373373

374374
@Override
@@ -404,12 +404,12 @@ public void uploadAsset(@NonNull File asset, @NonNull String logicalPath, boolea
404404

405405
@Override
406406
public void uploadAsset(@NonNull File asset, @NonNull String logicalPath, boolean overwrite) {
407-
this.uploadAsset(asset, logicalPath, overwrite, this.baseContext);
407+
this.uploadAsset(asset, logicalPath, overwrite, ExperimentContext.empty());
408408
}
409409

410410
@Override
411411
public void uploadAsset(@NonNull File asset, boolean overwrite) {
412-
this.uploadAsset(asset, asset.getName(), overwrite, this.baseContext);
412+
this.uploadAsset(asset, asset.getName(), overwrite, ExperimentContext.empty());
413413
}
414414

415415
@Override
@@ -422,7 +422,7 @@ public void logRemoteAsset(@NonNull URI uri, String logicalPath, boolean overwri
422422

423423
@Override
424424
public void logRemoteAsset(@NonNull URI uri, String logicalPath, boolean overwrite, Map<String, Object> metadata) {
425-
this.logRemoteAsset(uri, logicalPath, overwrite, metadata, this.baseContext);
425+
this.logRemoteAsset(uri, logicalPath, overwrite, metadata, ExperimentContext.empty());
426426
}
427427

428428
@Override
@@ -432,7 +432,7 @@ public void logRemoteAsset(@NonNull URI uri, @NonNull String logicalPath, boolea
432432

433433
@Override
434434
public void logRemoteAsset(@NonNull URI uri, boolean overwrite) {
435-
this.logRemoteAsset(uri, null, overwrite, null, this.baseContext);
435+
this.logRemoteAsset(uri, null, overwrite, null, ExperimentContext.empty());
436436
}
437437

438438
@Override
@@ -454,12 +454,27 @@ public void logCode(@NonNull File file, @NonNull ExperimentContext context) {
454454

455455
@Override
456456
public void logCode(@NonNull String code, @NonNull String logicalPath) {
457-
this.logCode(code, logicalPath, this.baseContext);
457+
this.logCode(code, logicalPath, ExperimentContext.empty());
458458
}
459459

460460
@Override
461461
public void logCode(@NonNull File file) {
462-
this.logCode(file, this.baseContext);
462+
this.logCode(file, ExperimentContext.empty());
463+
}
464+
465+
@Override
466+
public void logText(String text, ExperimentContext context, Map<String, Object> metadata) {
467+
468+
}
469+
470+
@Override
471+
public void logText(String text, ExperimentContext context) {
472+
473+
}
474+
475+
@Override
476+
public void logText(String text) {
477+
this.logText(text, ExperimentContext.empty());
463478
}
464479

465480
@Override
@@ -487,7 +502,7 @@ public void logModelFolder(@NonNull String modelName, @NonNull File folder, bool
487502
@Override
488503
public void logModelFolder(@NonNull String modelName, @NonNull File folder,
489504
boolean logFilePath, Map<String, Object> metadata) {
490-
this.logModelFolder(modelName, folder, logFilePath, metadata, this.baseContext);
505+
this.logModelFolder(modelName, folder, logFilePath, metadata, ExperimentContext.empty());
491506
}
492507

493508
@Override
@@ -513,7 +528,7 @@ public void logModel(@NonNull String modelName, @NonNull File file, @NonNull Str
513528
@Override
514529
public void logModel(@NonNull String modelName, @NonNull File file, @NonNull String logicalPath,
515530
boolean overwrite, Map<String, Object> metadata) {
516-
this.logModel(modelName, file, logicalPath, overwrite, metadata, this.baseContext);
531+
this.logModel(modelName, file, logicalPath, overwrite, metadata, ExperimentContext.empty());
517532
}
518533

519534
@Override
@@ -546,7 +561,7 @@ public void logModel(@NonNull String modelName, byte[] data, @NonNull String log
546561
@Override
547562
public void logModel(@NonNull String modelName, byte[] data, @NonNull String logicalPath, boolean overwrite,
548563
Map<String, Object> metadata) {
549-
this.logModel(modelName, data, logicalPath, overwrite, metadata, this.baseContext);
564+
this.logModel(modelName, data, logicalPath, overwrite, metadata, ExperimentContext.empty());
550565
}
551566

552567
@Override

comet-java-client/src/test/java/ml/comet/experiment/impl/LogAssetsSupportTest.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ public void testLogAndGetRemoteAssets() {
4949

5050
String secondAssetExpectedFileName = "secondAssetFile.extension";
5151
URI secondAssetLink = new URI("s3://bucket/folder/" + secondAssetExpectedFileName);
52+
experiment.setContext(TestUtils.SOME_CONTEXT_ID);
5253
experiment.logRemoteAsset(secondAssetLink, empty(), false, empty(),
53-
TestUtils.SOME_FULL_CONTEXT, Optional.of(onComplete));
54+
ExperimentContext.empty(), Optional.of(onComplete));
5455

5556
awaitForCondition(onComplete, "second remote asset onComplete timeout", 30);
5657

@@ -59,8 +60,11 @@ public void testLogAndGetRemoteAssets() {
5960
awaitForCondition(() -> experiment.getAllAssetList().size() == 2, "Assets was uploaded");
6061
List<LoggedExperimentAsset> assets = experiment.getAllAssetList();
6162

62-
validateRemoteAssetLink(assets, firstAssetLink, firstAssetFileName, TestUtils.SOME_METADATA);
63-
validateRemoteAssetLink(assets, secondAssetLink, secondAssetExpectedFileName, null);
63+
validateRemoteAssetLink(assets, firstAssetLink, firstAssetFileName, TestUtils.SOME_METADATA,
64+
new ExperimentContext(TestUtils.SOME_FULL_CONTEXT.getStep(), 0,
65+
TestUtils.SOME_FULL_CONTEXT.getContext()));
66+
validateRemoteAssetLink(assets, secondAssetLink, secondAssetExpectedFileName, null,
67+
new ExperimentContext(0, 0, TestUtils.SOME_CONTEXT_ID));
6468
} catch (Exception e) {
6569
fail(e);
6670
}
@@ -82,8 +86,9 @@ public void testLogAndGetAssets() throws Exception {
8286
awaitForCondition(onComplete, "image file onComplete timeout", 30);
8387

8488
onComplete = new OnlineExperimentTest.OnCompleteAction();
89+
experiment.setContext(TestUtils.SOME_CONTEXT_ID);
8590
experiment.logAssetFileAsync(Objects.requireNonNull(TestUtils.getFile(SOME_TEXT_FILE_NAME)), SOME_TEXT_FILE_NAME,
86-
false, TestUtils.SOME_FULL_CONTEXT, Optional.of(onComplete));
91+
false, ExperimentContext.empty(), Optional.of(onComplete));
8792
awaitForCondition(onComplete, "text file onComplete timeout", 30);
8893

8994
// wait for assets become available and validate results
@@ -92,7 +97,8 @@ public void testLogAndGetAssets() throws Exception {
9297

9398
List<LoggedExperimentAsset> assets = experiment.getAllAssetList();
9499
validateAsset(assets, IMAGE_FILE_NAME, IMAGE_FILE_SIZE, TestUtils.SOME_FULL_CONTEXT);
95-
validateAsset(assets, SOME_TEXT_FILE_NAME, SOME_TEXT_FILE_SIZE, TestUtils.SOME_FULL_CONTEXT);
100+
validateAsset(assets, SOME_TEXT_FILE_NAME, SOME_TEXT_FILE_SIZE,
101+
new ExperimentContext(0, 0, TestUtils.SOME_CONTEXT_ID));
96102

97103
// update one of the assets and validate
98104
//
@@ -155,19 +161,28 @@ public void testLogAndGetAssetsFolder(boolean flatDirectoryStructure, boolean re
155161
}
156162

157163
static void validateRemoteAssetLink(List<LoggedExperimentAsset> assets, URI uri,
158-
String fileName, Map<String, Object> metadata) {
164+
String fileName, Map<String, Object> metadata,
165+
ExperimentContext context) {
159166
if (Objects.nonNull(metadata)) {
160167
assertTrue(assets.stream()
161168
.filter(asset -> Objects.equals(uri, asset.getLink().orElse(null)))
162169
.allMatch(asset -> asset.isRemote()
163170
&& Objects.equals(asset.getLogicalPath(), fileName)
164-
&& Objects.equals(asset.getMetadata(), metadata)));
171+
&& Objects.equals(asset.getMetadata(), metadata)
172+
));
165173
} else {
166174
assertTrue(assets.stream()
167175
.filter(asset -> Objects.equals(uri, asset.getLink().orElse(null)))
168176
.allMatch(asset -> asset.isRemote()
169177
&& Objects.equals(asset.getLogicalPath(), fileName)
170-
&& asset.getMetadata().isEmpty()));
178+
&& asset.getMetadata().isEmpty()
179+
));
171180
}
181+
// check context
182+
assertTrue(assets.stream()
183+
.filter(asset -> Objects.equals(uri, asset.getLink().orElse(null)))
184+
.allMatch(asset -> asset.getExperimentContext().isPresent()
185+
&& Objects.equals(asset.getExperimentContext().get(), context)
186+
));
172187
}
173188
}

0 commit comments

Comments
 (0)