Skip to content

Commit cc46c88

Browse files
Sync model state across tabs
1 parent fcc6bfb commit cc46c88

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

src/store.ts

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ import { getDetectedAction } from "./utils/prediction";
5858

5959
export enum BroadcastChannelMessages {
6060
RELOAD_PROJECT = "reload-project",
61+
REMOVE_MODEL = "remove-model",
6162
}
6263
const broadcastChannel = new BroadcastChannel("ml");
6364

@@ -210,6 +211,7 @@ export interface Actions {
210211
trainModelFlowStart: (callback?: () => void) => Promise<void>;
211212
closeTrainModelDialogs: () => void;
212213
trainModel(): Promise<boolean>;
214+
removeModel(): void;
213215
setSettings(update: Partial<Settings>): Promise<void>;
214216
setLanguage(languageId: string): Promise<void>;
215217

@@ -875,6 +877,12 @@ const createMlStore = (logging: Logging) => {
875877
return !trainingResult.error;
876878
},
877879

880+
removeModel(): void {
881+
set({
882+
model: undefined,
883+
});
884+
},
885+
878886
async resetProject(): Promise<void> {
879887
const {
880888
project: previousProject,
@@ -1416,28 +1424,34 @@ useStore.setState(
14161424
"setDataWindow"
14171425
);
14181426

1419-
tf.loadLayersModel(modelUrl)
1420-
.then((model) => {
1427+
const loadModelFromStorage = async () => {
1428+
try {
1429+
const model = await tf.loadLayersModel(modelUrl);
14211430
if (model) {
14221431
useStore.setState({ model }, false, "loadModel");
14231432
}
1424-
})
1425-
.catch(() => {
1433+
} catch (err) {
14261434
// This happens if there's no model.
1427-
});
1435+
}
1436+
};
14281437

1429-
useStore.subscribe((state, prevState) => {
1438+
useStore.subscribe(async (state, prevState) => {
14301439
const { model: newModel } = state;
14311440
const { model: previousModel } = prevState;
14321441
if (newModel !== previousModel) {
14331442
if (!newModel) {
1434-
tf.io.removeModel(modelUrl).catch(() => {
1435-
// No IndexedDB/no model.
1436-
});
1443+
try {
1444+
await tf.io.removeModel(modelUrl);
1445+
broadcastChannel.postMessage(BroadcastChannelMessages.REMOVE_MODEL);
1446+
} catch (err) {
1447+
// IndexedDB not available?
1448+
}
14371449
} else {
1438-
newModel.save(modelUrl).catch(() => {
1450+
try {
1451+
await newModel.save(modelUrl);
1452+
} catch (err) {
14391453
// IndexedDB not available?
1440-
});
1454+
}
14411455
}
14421456
}
14431457
});
@@ -1564,11 +1578,19 @@ const storageWithErrHandling = async <T>(
15641578
export const loadProjectFromStorage = async () => {
15651579
const loadProjectFromStorage = useStore.getState().loadProjectFromStorage;
15661580
await loadProjectFromStorage();
1581+
await loadModelFromStorage();
15671582
return true;
15681583
};
15691584

15701585
broadcastChannel.onmessage = async (event) => {
1571-
if (event.data === BroadcastChannelMessages.RELOAD_PROJECT) {
1572-
await loadProjectFromStorage();
1586+
switch (event.data) {
1587+
case BroadcastChannelMessages.RELOAD_PROJECT: {
1588+
await loadProjectFromStorage();
1589+
break;
1590+
}
1591+
case BroadcastChannelMessages.REMOVE_MODEL: {
1592+
useStore.getState().removeModel();
1593+
break;
1594+
}
15731595
}
15741596
};

0 commit comments

Comments
 (0)