1+ import * as fs from "tns-core-modules/file-system" ;
12import { ImageSource } from "tns-core-modules/image-source" ;
2- import { MLKitOptions } from "../" ;
33import { MLKitCustomModelOptions , MLKitCustomModelResult , MLKitCustomModelResultValue } from "./" ;
44import { getLabelsFromAppFolder , MLKitCustomModel as MLKitCustomModelBase } from "./custommodel-common" ;
5- import * as fs from "tns-core-modules/file-system" ;
65
76declare const com : any ;
87declare const org : any ; // TODO remove after regenerating typings
98
109export class MLKitCustomModel extends MLKitCustomModelBase {
10+ private detector ;
11+ private onFailureListener ;
12+ private inputOutputOptions ;
1113
1214 protected createDetector ( ) : any {
13- return getInterpreter ( null ) ; // TODO
15+ this . detector = getInterpreter ( this . localModelFile ) ;
16+ return this . detector ;
17+ }
18+
19+ protected runDetector ( imageByteBuffer , previewWidth , previewHeight ) : void {
20+ if ( this . detectorBusy ) {
21+ return ;
22+ }
23+
24+ this . detectorBusy = true ;
25+
26+ if ( ! this . onFailureListener ) {
27+ this . onFailureListener = new com . google . android . gms . tasks . OnFailureListener ( {
28+ onFailure : exception => {
29+ console . log ( exception . getMessage ( ) ) ;
30+ this . detectorBusy = false ;
31+ }
32+ } ) ;
33+ }
34+
35+ const modelExpectsWidth = this . modelInputShape [ 1 ] ;
36+ const modelExpectsHeight = this . modelInputShape [ 2 ] ;
37+ const isQuantized = this . modelInputType !== "FLOAT32" ;
38+
39+ if ( ! this . inputOutputOptions ) {
40+ let intArrayIn = Array . create ( "int" , 4 ) ;
41+ intArrayIn [ 0 ] = this . modelInputShape [ 0 ] ;
42+ intArrayIn [ 1 ] = modelExpectsWidth ;
43+ intArrayIn [ 2 ] = modelExpectsHeight ;
44+ intArrayIn [ 3 ] = this . modelInputShape [ 3 ] ;
45+
46+ const inputType = isQuantized ? com . google . firebase . ml . custom . FirebaseModelDataType . BYTE : com . google . firebase . ml . custom . FirebaseModelDataType . FLOAT32 ;
47+
48+ let intArrayOut = Array . create ( "int" , 2 ) ;
49+ intArrayOut [ 0 ] = 1 ;
50+ intArrayOut [ 1 ] = this . labels . length ;
51+
52+ this . inputOutputOptions = new com . google . firebase . ml . custom . FirebaseModelInputOutputOptions . Builder ( )
53+ . setInputFormat ( 0 , inputType , intArrayIn )
54+ . setOutputFormat ( 0 , inputType , intArrayOut )
55+ . build ( ) ;
56+ }
57+
58+ const input = org . nativescript . plugins . firebase . mlkit . BitmapUtil . byteBufferToByteBuffer ( imageByteBuffer , previewWidth , previewHeight , modelExpectsWidth , modelExpectsHeight , isQuantized ) ;
59+ const inputs = new com . google . firebase . ml . custom . FirebaseModelInputs . Builder ( )
60+ . add ( input ) // add as many input arrays as your model requires
61+ . build ( ) ;
62+
63+ this . detector
64+ . run ( inputs , this . inputOutputOptions )
65+ . addOnSuccessListener ( this . onSuccessListener )
66+ . addOnFailureListener ( this . onFailureListener ) ;
1467 }
1568
1669 protected createSuccessListener ( ) : any {
17- return new com . google . android . gms . tasks . OnSuccessListener ( {
18- onSuccess : labels => {
70+ this . onSuccessListener = new com . google . android . gms . tasks . OnSuccessListener ( {
71+ onSuccess : output => {
72+ const probabilities : Array < number > = output . getOutput ( 0 ) [ 0 ] ;
1973
20- if ( labels . size ( ) === 0 ) return ;
74+ if ( this . labels . length !== probabilities . length ) {
75+ console . log ( `The number of labels (${ this . labels . length } ) is not equal to the interpretation result (${ probabilities . length } )!` ) ;
76+ return ;
77+ }
2178
2279 const result = < MLKitCustomModelResult > {
23- result : [ ]
80+ result : getSortedResult ( this . labels , probabilities , this . maxResults )
2481 } ;
2582
26- // see https://github.com/firebase/quickstart-android/blob/0f4c86877fc5f771cac95797dffa8bd026dd9dc7/mlkit/app/src/main/java/com/google/firebase/samples/apps/mlkit/textrecognition/TextRecognitionProcessor.java#L62
27- for ( let i = 0 ; i < labels . size ( ) ; i ++ ) {
28- const label = labels . get ( i ) ;
29- result . result . push ( {
30- text : label . getLabel ( ) ,
31- confidence : label . getConfidence ( )
32- } ) ;
33- }
34-
3583 this . notify ( {
3684 eventName : MLKitCustomModel . scanResultEvent ,
3785 object : this ,
3886 value : result
3987 } ) ;
88+
89+ this . detectorBusy = false ;
4090 }
4191 } ) ;
92+
93+ return this . onSuccessListener ;
4294 }
4395}
4496
45- // TODO should probably cache this
46- function getInterpreter ( options : MLKitCustomModelOptions ) : any {
97+ function getInterpreter ( localModelFile ?: string ) : any {
4798 const firModelOptionsBuilder = new com . google . firebase . ml . custom . FirebaseModelOptions . Builder ( ) ;
4899
49100 let localModelRegistrationSuccess = false ;
50101 let cloudModelRegistrationSuccess = false ;
51102 let localModelName ;
52103
53- if ( options . localModelFile ) {
54- localModelName = options . localModelFile . lastIndexOf ( "/" ) === - 1 ? options . localModelFile : options . localModelFile . substring ( options . localModelFile . lastIndexOf ( "/" ) + 1 ) ;
104+ if ( localModelFile ) {
105+ localModelName = localModelFile . lastIndexOf ( "/" ) === - 1 ? localModelFile : localModelFile . substring ( localModelFile . lastIndexOf ( "/" ) + 1 ) ;
55106
56107 if ( com . google . firebase . ml . custom . FirebaseModelManager . getInstance ( ) . getLocalModelSource ( localModelName ) ) {
57108 localModelRegistrationSuccess = true ;
58109 firModelOptionsBuilder . setLocalModelName ( localModelName )
59110 } else {
60- console . log ( "model not yet loaded: " + options . localModelFile ) ;
111+ console . log ( "model not yet loaded: " + localModelFile ) ;
61112
62113 const firModelLocalSourceBuilder = new com . google . firebase . ml . custom . model . FirebaseLocalModelSource . Builder ( localModelName ) ;
63114
64- if ( options . localModelFile . indexOf ( "~/" ) === 0 ) {
65- firModelLocalSourceBuilder . setFilePath ( fs . knownFolders . currentApp ( ) . path + options . localModelFile . substring ( 1 ) ) ;
115+ if ( localModelFile . indexOf ( "~/" ) === 0 ) {
116+ firModelLocalSourceBuilder . setFilePath ( fs . knownFolders . currentApp ( ) . path + localModelFile . substring ( 1 ) ) ;
66117 } else {
67118 // note that this doesn't seem to work, let's advice users to use ~/ for now
68- firModelLocalSourceBuilder . setAssetFilePath ( options . localModelFile ) ;
119+ firModelLocalSourceBuilder . setAssetFilePath ( localModelFile ) ;
69120 }
70121
71122 localModelRegistrationSuccess = com . google . firebase . ml . custom . FirebaseModelManager . getInstance ( ) . registerLocalModelSource ( firModelLocalSourceBuilder . build ( ) ) ;
@@ -91,7 +142,7 @@ function getInterpreter(options: MLKitCustomModelOptions): any {
91142export function useCustomModel ( options : MLKitCustomModelOptions ) : Promise < MLKitCustomModelResult > {
92143 return new Promise ( ( resolve , reject ) => {
93144 try {
94- const interpreter = getInterpreter ( options ) ;
145+ const interpreter = getInterpreter ( options . localModelFile ) ;
95146
96147 let labels : Array < string > ;
97148 if ( options . labelsFile . indexOf ( "~/" ) === 0 ) {
@@ -130,7 +181,8 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
130181 intArrayIn [ 2 ] = options . modelInput [ 0 ] . shape [ 2 ] ;
131182 intArrayIn [ 3 ] = options . modelInput [ 0 ] . shape [ 3 ] ;
132183
133- const inputType = options . modelInput [ 0 ] . type === "FLOAT32" ? com . google . firebase . ml . custom . FirebaseModelDataType . FLOAT32 : com . google . firebase . ml . custom . FirebaseModelDataType . BYTE ;
184+ const isQuantized = options . modelInput [ 0 ] . type !== "FLOAT32" ;
185+ const inputType = isQuantized ? com . google . firebase . ml . custom . FirebaseModelDataType . BYTE : com . google . firebase . ml . custom . FirebaseModelDataType . FLOAT32 ;
134186
135187 let intArrayOut = Array . create ( "int" , 2 ) ;
136188 intArrayOut [ 0 ] = 1 ;
@@ -142,9 +194,7 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
142194 . build ( ) ;
143195
144196 const image : android . graphics . Bitmap = options . image instanceof ImageSource ? options . image . android : options . image . imageSource . android ;
145-
146- const input = org . nativescript . plugins . firebase . mlkit . BitmapUtil . bitmapToByteBuffer ( image , options . modelInput [ 0 ] . shape [ 1 ] , options . modelInput [ 0 ] . shape [ 2 ] ) ;
147-
197+ const input = org . nativescript . plugins . firebase . mlkit . BitmapUtil . bitmapToByteBuffer ( image , options . modelInput [ 0 ] . shape [ 1 ] , options . modelInput [ 0 ] . shape [ 2 ] , isQuantized ) ;
148198 const inputs = new com . google . firebase . ml . custom . FirebaseModelInputs . Builder ( )
149199 . add ( input ) // add as many input arrays as your model requires
150200 . build ( ) ;
@@ -161,16 +211,11 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
161211 } ) ;
162212}
163213
164- function getImage ( options : MLKitOptions ) : any /* com.google.firebase.ml.vision.common.FirebaseVisionImage */ {
165- const image : android . graphics . Bitmap = options . image instanceof ImageSource ? options . image . android : options . image . imageSource . android ;
166- return com . google . firebase . ml . vision . common . FirebaseVisionImage . fromBitmap ( image ) ;
167- }
168-
169- function getSortedResult ( labels : Array < string > , probabilities : Array < number > , maxResults ?: number ) : Array < MLKitCustomModelResultValue > {
214+ function getSortedResult ( labels : Array < string > , probabilities : Array < number > , maxResults = 5 ) : Array < MLKitCustomModelResultValue > {
170215 const result : Array < MLKitCustomModelResultValue > = [ ] ;
171216 labels . forEach ( ( text , i ) => result . push ( { text, confidence : probabilities [ i ] } ) ) ;
172217 result . sort ( ( a , b ) => a . confidence < b . confidence ? 1 : ( a . confidence === b . confidence ? 0 : - 1 ) ) ;
173- if ( maxResults && result . length > maxResults ) {
218+ if ( result . length > maxResults ) {
174219 result . splice ( maxResults ) ;
175220 }
176221 result . map ( r => r . confidence = ( r . confidence & 0xff ) / 255.0 ) ;
0 commit comments