@@ -58,6 +58,7 @@ import { getDetectedAction } from "./utils/prediction";
5858
5959export enum BroadcastChannelMessages {
6060 RELOAD_PROJECT = "reload-project" ,
61+ REMOVE_MODEL = "remove-model" ,
6162}
6263const broadcastChannel = new BroadcastChannel ( "ml" ) ;
6364
@@ -210,6 +211,7 @@ export interface Actions {
210211 trainModelFlowStart : ( callback ?: ( ) => void ) => Promise < void > ;
211212 closeTrainModelDialogs : ( ) => void ;
212213 trainModel ( ) : Promise < boolean > ;
214+ removeModel ( ) : void ;
213215 setSettings ( update : Partial < Settings > ) : Promise < void > ;
214216 setLanguage ( languageId : string ) : Promise < void > ;
215217
@@ -875,6 +877,12 @@ const createMlStore = (logging: Logging) => {
875877 return ! trainingResult . error ;
876878 } ,
877879
880+ removeModel ( ) : void {
881+ set ( {
882+ model : undefined ,
883+ } ) ;
884+ } ,
885+
878886 async resetProject ( ) : Promise < void > {
879887 const {
880888 project : previousProject ,
@@ -1416,28 +1424,34 @@ useStore.setState(
14161424 "setDataWindow"
14171425) ;
14181426
1419- tf . loadLayersModel ( modelUrl )
1420- . then ( ( model ) => {
1427+ const loadModelFromStorage = async ( ) => {
1428+ try {
1429+ const model = await tf . loadLayersModel ( modelUrl ) ;
14211430 if ( model ) {
14221431 useStore . setState ( { model } , false , "loadModel" ) ;
14231432 }
1424- } )
1425- . catch ( ( ) => {
1433+ } catch ( err ) {
14261434 // This happens if there's no model.
1427- } ) ;
1435+ }
1436+ } ;
14281437
1429- useStore . subscribe ( ( state , prevState ) => {
1438+ useStore . subscribe ( async ( state , prevState ) => {
14301439 const { model : newModel } = state ;
14311440 const { model : previousModel } = prevState ;
14321441 if ( newModel !== previousModel ) {
14331442 if ( ! newModel ) {
1434- tf . io . removeModel ( modelUrl ) . catch ( ( ) => {
1435- // No IndexedDB/no model.
1436- } ) ;
1443+ try {
1444+ await tf . io . removeModel ( modelUrl ) ;
1445+ broadcastChannel . postMessage ( BroadcastChannelMessages . REMOVE_MODEL ) ;
1446+ } catch ( err ) {
1447+ // IndexedDB not available?
1448+ }
14371449 } else {
1438- newModel . save ( modelUrl ) . catch ( ( ) => {
1450+ try {
1451+ await newModel . save ( modelUrl ) ;
1452+ } catch ( err ) {
14391453 // IndexedDB not available?
1440- } ) ;
1454+ }
14411455 }
14421456 }
14431457} ) ;
@@ -1564,11 +1578,19 @@ const storageWithErrHandling = async <T>(
15641578export const loadProjectFromStorage = async ( ) => {
15651579 const loadProjectFromStorage = useStore . getState ( ) . loadProjectFromStorage ;
15661580 await loadProjectFromStorage ( ) ;
1581+ await loadModelFromStorage ( ) ;
15671582 return true ;
15681583} ;
15691584
15701585broadcastChannel . onmessage = async ( event ) => {
1571- if ( event . data === BroadcastChannelMessages . RELOAD_PROJECT ) {
1572- await loadProjectFromStorage ( ) ;
1586+ switch ( event . data ) {
1587+ case BroadcastChannelMessages . RELOAD_PROJECT : {
1588+ await loadProjectFromStorage ( ) ;
1589+ break ;
1590+ }
1591+ case BroadcastChannelMessages . REMOVE_MODEL : {
1592+ useStore . getState ( ) . removeModel ( ) ;
1593+ break ;
1594+ }
15731595 }
15741596} ;
0 commit comments