Skip to content
This repository was archived by the owner on Apr 4, 2023. It is now read-only.

Commit df29885

Browse files
Add custom (TensorFlow Lite) models support to the ML Kit feature #702
1 parent a72d653 commit df29885

18 files changed

+360
-99
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
cereals
2+
juice
3+
nutella
1.59 MB
Binary file not shown.

demo-ng/app/tabs/mlkit/custommodel/custommodel.component.html

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
<MLKitCustomModel
77
width="100%"
88
height="100%"
9-
confidenceThreshold="0.6"
9+
localModelFile="~/custommodel/inception/inception_v3_quant.tflite"
10+
labelsFile="~/custommodel/inception/inception_labels.txt"
11+
modelInputShape="1, 299, 299, 3"
12+
modelInputType="QUANT"
13+
processEveryNthFrame="30"
14+
maxResults="5"
1015
(scanResult)="onCustomModelResult($event)">
1116
</MLKitCustomModel>
1217

@@ -33,16 +38,14 @@
3338
<Label height="1" marginBottom="1" borderBottomWidth="1" borderColor="rgba(81, 184, 237, 1)"></Label>
3439
</StackLayout>
3540
</GridLayout>
36-
<Label [text]="result" row="0" rowSpan="3" col="0"></Label>
37-
38-
<!--ListView separatorColor="transparent" row="0" rowSpan="3" col="0" colSpan="3" [items]="result" class="m-t-20" backgroundColor="transparent">
41+
<ListView separatorColor="transparent" row="0" rowSpan="3" col="0" colSpan="3" [items]="labels" class="m-t-20" backgroundColor="transparent">
3942
<ng-template let-item="item">
4043
<GridLayout columns="3*, 2*">
4144
<Label col="0" class="mlkit-result" textWrap="true" [text]="item.text"></Label>
4245
<Label col="1" class="mlkit-result" textWrap="true" [text]="item.confidence | number"></Label>
4346
</GridLayout>
4447
</ng-template>
45-
</ListView-->
48+
</ListView>
4649
</GridLayout>
4750

4851
<GridLayout rows="auto" columns="auto, auto" horizontalAlignment="right" class="m-t-4 m-r-8">

demo-ng/app/tabs/mlkit/custommodel/custommodel.component.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ import { AbstractMLKitViewComponent } from "~/tabs/mlkit/abstract.mlkitview.comp
88
templateUrl: "./custommodel.component.html",
99
})
1010
export class CustomModelComponent extends AbstractMLKitViewComponent {
11-
result: any;
11+
labels: Array<{
12+
text: string;
13+
confidence: number;
14+
}>;
1215

1316
onCustomModelResult(scanResult: any): void {
1417
const value: MLKitCustomModelResult = scanResult.value;
15-
this.result = value.result;
18+
this.labels = value.result;
1619
}
1720
}

demo-ng/app/tabs/mlkit/mlkit.component.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,15 @@ export class MLKitComponent {
244244
// cloudModelName: "~/mobilenet_quant_v2_1_0_299",
245245
// cloudModelName: "~/inception_v3_quant",
246246

247-
localModelFile: "~/custommodel/mobilenet/mobilenet_quant_v2_1.0_299.tflite",
248-
labelsFile: "~/custommodel/mobilenet/mobilenet_labels.txt",
247+
// note that there's an issue with this model (making the app crash): "ValueError: Model provided has model identifier 'Mobi', should be 'TFL3'" (reported by https://github.com/EddyVerbruggen/ns-mlkit-tflite-curated/blob/master/scripts/get_model_details.py)
248+
// localModelFile: "~/custommodel/nutella/retrained_quantized_model.tflite",
249+
// labelsFile: "~/custommodel/nutella/nutella_labels.txt",
249250

250-
// localModelFile: "~/custommodel/inception/inception_v3_quant.tflite",
251-
// labelsFile: "~/custommodel/inception/inception_labels.txt",
251+
// localModelFile: "~/custommodel/mobilenet/mobilenet_quant_v2_1.0_299.tflite",
252+
// labelsFile: "~/custommodel/mobilenet/mobilenet_labels.txt",
253+
254+
localModelFile: "~/custommodel/inception/inception_v3_quant.tflite",
255+
labelsFile: "~/custommodel/inception/inception_labels.txt",
252256

253257
maxResults: 5,
254258
modelInput: [{

demo-ng/app/tabs/tabs-routing.module.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ import { TextRecognitionComponent } from "~/tabs/mlkit/textrecognition/textrecog
77
import { BarcodeScanningComponent } from "~/tabs/mlkit/barcodescanning/barcodescanning.component";
88
import { FaceDetectionComponent } from "~/tabs/mlkit/facedetection/facedetection.component";
99
import { ImageLabelingComponent } from "~/tabs/mlkit/imagelabeling/imagelabeling.component";
10+
import { CustomModelComponent } from "~/tabs/mlkit/custommodel/custommodel.component";
1011

1112
const routes: Routes = [
1213
{ path: "", component: TabsComponent },
1314
{ path: "mlkit/textrecognition", component: TextRecognitionComponent },
1415
{ path: "mlkit/barcodescanning", component: BarcodeScanningComponent },
1516
{ path: "mlkit/facedetection", component: FaceDetectionComponent },
16-
{ path: "mlkit/imagelabeling", component: ImageLabelingComponent }
17+
{ path: "mlkit/imagelabeling", component: ImageLabelingComponent },
18+
{ path: "mlkit/custommodel", component: CustomModelComponent }
1719
];
1820

1921
@NgModule({

demo-ng/app/tabs/tabs.module.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ import { TextRecognitionComponent } from "~/tabs/mlkit/textrecognition/textrecog
1010
import { BarcodeScanningComponent } from "~/tabs/mlkit/barcodescanning/barcodescanning.component";
1111
import { FaceDetectionComponent } from "~/tabs/mlkit/facedetection/facedetection.component";
1212
import { ImageLabelingComponent } from "~/tabs/mlkit/imagelabeling/imagelabeling.component";
13+
import { CustomModelComponent } from "~/tabs/mlkit/custommodel/custommodel.component";
1314

1415
import { registerElement } from "nativescript-angular/element-registry";
1516
registerElement("MLKitBarcodeScanner", () => require("nativescript-plugin-firebase/mlkit/barcodescanning").MLKitBarcodeScanner);
1617
registerElement("MLKitFaceDetection", () => require("nativescript-plugin-firebase/mlkit/facedetection").MLKitFaceDetection);
1718
registerElement("MLKitTextRecognition", () => require("nativescript-plugin-firebase/mlkit/textrecognition").MLKitTextRecognition);
1819
registerElement("MLKitImageLabeling", () => require("nativescript-plugin-firebase/mlkit/imagelabeling").MLKitImageLabeling);
20+
registerElement("MLKitCustomModel", () => require("nativescript-plugin-firebase/mlkit/custommodel").MLKitCustomModel);
1921

2022
@NgModule({
2123
imports: [
@@ -29,7 +31,8 @@ registerElement("MLKitImageLabeling", () => require("nativescript-plugin-firebas
2931
ImageLabelingComponent,
3032
MLKitComponent,
3133
TabsComponent,
32-
TextRecognitionComponent
34+
TextRecognitionComponent,
35+
CustomModelComponent
3336
],
3437
schemas: [
3538
NO_ERRORS_SCHEMA

src/mlkit/custommodel/custommodel-common.ts

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,83 @@
11
import * as fs from "tns-core-modules/file-system";
2+
import { Property } from "tns-core-modules/ui/core/properties";
23
import { MLKitCameraView } from "../mlkit-cameraview";
4+
import { MLKitCustomModelType } from "./index";
5+
6+
export const localModelFileProperty = new Property<MLKitCustomModel, string>({
7+
name: "localModelFile",
8+
defaultValue: null,
9+
});
10+
11+
export const labelsFileProperty = new Property<MLKitCustomModel, string>({
12+
name: "labelsFile",
13+
defaultValue: null,
14+
});
15+
16+
export const modelInputShapeProperty = new Property<MLKitCustomModel, string>({
17+
name: "modelInputShape",
18+
defaultValue: null,
19+
});
20+
21+
export const modelInputTypeProperty = new Property<MLKitCustomModel, string>({
22+
name: "modelInputType",
23+
defaultValue: null,
24+
});
25+
26+
// TODO could combine this with 'confidenceThreshold'
27+
export const maxResultsProperty = new Property<MLKitCustomModel, number>({
28+
name: "maxResults",
29+
defaultValue: 5
30+
});
331

432
export abstract class MLKitCustomModel extends MLKitCameraView {
533
static scanResultEvent: string = "scanResult";
34+
protected localModelFile: string;
35+
protected labelsFile: string;
36+
protected maxResults: number;
37+
protected modelInputShape: Array<number>;
38+
protected modelInputType: MLKitCustomModelType;
39+
40+
protected onSuccessListener;
41+
protected detectorBusy: boolean;
42+
43+
protected labels: Array<string>;
44+
45+
[localModelFileProperty.setNative](value: string) {
46+
this.localModelFile = value;
47+
}
48+
49+
[labelsFileProperty.setNative](value: string) {
50+
this.labelsFile = value;
51+
if (value.indexOf("~/") === 0) {
52+
this.labels = getLabelsFromAppFolder(value);
53+
} else {
54+
// no dice loading from assets yet, let's advice users to use ~/ for now
55+
console.log("For the 'labelsFile' property, use the ~/ prefix for now..");
56+
return;
57+
}
58+
}
59+
60+
[maxResultsProperty.setNative](value: any) {
61+
this.maxResults = parseInt(value);
62+
}
63+
64+
[modelInputShapeProperty.setNative](value: string) {
65+
if ((typeof value) === "string") {
66+
this.modelInputShape = value.split(",").map(v => parseInt(v.trim()));
67+
}
68+
}
69+
70+
[modelInputTypeProperty.setNative](value: MLKitCustomModelType) {
71+
this.modelInputType = value;
72+
}
673
}
774

75+
localModelFileProperty.register(MLKitCustomModel);
76+
labelsFileProperty.register(MLKitCustomModel);
77+
maxResultsProperty.register(MLKitCustomModel);
78+
modelInputShapeProperty.register(MLKitCustomModel);
79+
modelInputTypeProperty.register(MLKitCustomModel);
80+
881
export function getLabelsFromAppFolder(labelsFile: string): Array<string> {
982
const labelsPath = fs.knownFolders.currentApp().path + labelsFile.substring(1);
1083
return getLabelsFromFile(labelsPath);

src/mlkit/custommodel/index.android.ts

Lines changed: 81 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,122 @@
1+
import * as fs from "tns-core-modules/file-system";
12
import { ImageSource } from "tns-core-modules/image-source";
2-
import { MLKitOptions } from "../";
33
import { MLKitCustomModelOptions, MLKitCustomModelResult, MLKitCustomModelResultValue } from "./";
44
import { getLabelsFromAppFolder, MLKitCustomModel as MLKitCustomModelBase } from "./custommodel-common";
5-
import * as fs from "tns-core-modules/file-system";
65

76
declare const com: any;
87
declare const org: any; // TODO remove after regenerating typings
98

109
export class MLKitCustomModel extends MLKitCustomModelBase {
10+
private detector;
11+
private onFailureListener;
12+
private inputOutputOptions;
1113

1214
protected createDetector(): any {
13-
return getInterpreter(null); // TODO
15+
this.detector = getInterpreter(this.localModelFile);
16+
return this.detector;
17+
}
18+
19+
protected runDetector(imageByteBuffer, previewWidth, previewHeight): void {
20+
if (this.detectorBusy) {
21+
return;
22+
}
23+
24+
this.detectorBusy = true;
25+
26+
if (!this.onFailureListener) {
27+
this.onFailureListener = new com.google.android.gms.tasks.OnFailureListener({
28+
onFailure: exception => {
29+
console.log(exception.getMessage());
30+
this.detectorBusy = false;
31+
}
32+
});
33+
}
34+
35+
const modelExpectsWidth = this.modelInputShape[1];
36+
const modelExpectsHeight = this.modelInputShape[2];
37+
const isQuantized = this.modelInputType !== "FLOAT32";
38+
39+
if (!this.inputOutputOptions) {
40+
let intArrayIn = Array.create("int", 4);
41+
intArrayIn[0] = this.modelInputShape[0];
42+
intArrayIn[1] = modelExpectsWidth;
43+
intArrayIn[2] = modelExpectsHeight;
44+
intArrayIn[3] = this.modelInputShape[3];
45+
46+
const inputType = isQuantized ? com.google.firebase.ml.custom.FirebaseModelDataType.BYTE : com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32;
47+
48+
let intArrayOut = Array.create("int", 2);
49+
intArrayOut[0] = 1;
50+
intArrayOut[1] = this.labels.length;
51+
52+
this.inputOutputOptions = new com.google.firebase.ml.custom.FirebaseModelInputOutputOptions.Builder()
53+
.setInputFormat(0, inputType, intArrayIn)
54+
.setOutputFormat(0, inputType, intArrayOut)
55+
.build();
56+
}
57+
58+
const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.byteBufferToByteBuffer(imageByteBuffer, previewWidth, previewHeight, modelExpectsWidth, modelExpectsHeight, isQuantized);
59+
const inputs = new com.google.firebase.ml.custom.FirebaseModelInputs.Builder()
60+
.add(input) // add as many input arrays as your model requires
61+
.build();
62+
63+
this.detector
64+
.run(inputs, this.inputOutputOptions)
65+
.addOnSuccessListener(this.onSuccessListener)
66+
.addOnFailureListener(this.onFailureListener);
1467
}
1568

1669
protected createSuccessListener(): any {
17-
return new com.google.android.gms.tasks.OnSuccessListener({
18-
onSuccess: labels => {
70+
this.onSuccessListener = new com.google.android.gms.tasks.OnSuccessListener({
71+
onSuccess: output => {
72+
const probabilities: Array<number> = output.getOutput(0)[0];
1973

20-
if (labels.size() === 0) return;
74+
if (this.labels.length !== probabilities.length) {
75+
console.log(`The number of labels (${this.labels.length}) is not equal to the interpretation result (${probabilities.length})!`);
76+
return;
77+
}
2178

2279
const result = <MLKitCustomModelResult>{
23-
result: []
80+
result: getSortedResult(this.labels, probabilities, this.maxResults)
2481
};
2582

26-
// see https://github.com/firebase/quickstart-android/blob/0f4c86877fc5f771cac95797dffa8bd026dd9dc7/mlkit/app/src/main/java/com/google/firebase/samples/apps/mlkit/textrecognition/TextRecognitionProcessor.java#L62
27-
for (let i = 0; i < labels.size(); i++) {
28-
const label = labels.get(i);
29-
result.result.push({
30-
text: label.getLabel(),
31-
confidence: label.getConfidence()
32-
});
33-
}
34-
3583
this.notify({
3684
eventName: MLKitCustomModel.scanResultEvent,
3785
object: this,
3886
value: result
3987
});
88+
89+
this.detectorBusy = false;
4090
}
4191
});
92+
93+
return this.onSuccessListener;
4294
}
4395
}
4496

45-
// TODO should probably cache this
46-
function getInterpreter(options: MLKitCustomModelOptions): any {
97+
function getInterpreter(localModelFile?: string): any {
4798
const firModelOptionsBuilder = new com.google.firebase.ml.custom.FirebaseModelOptions.Builder();
4899

49100
let localModelRegistrationSuccess = false;
50101
let cloudModelRegistrationSuccess = false;
51102
let localModelName;
52103

53-
if (options.localModelFile) {
54-
localModelName = options.localModelFile.lastIndexOf("/") === -1 ? options.localModelFile : options.localModelFile.substring(options.localModelFile.lastIndexOf("/") + 1);
104+
if (localModelFile) {
105+
localModelName = localModelFile.lastIndexOf("/") === -1 ? localModelFile : localModelFile.substring(localModelFile.lastIndexOf("/") + 1);
55106

56107
if (com.google.firebase.ml.custom.FirebaseModelManager.getInstance().getLocalModelSource(localModelName)) {
57108
localModelRegistrationSuccess = true;
58109
firModelOptionsBuilder.setLocalModelName(localModelName)
59110
} else {
60-
console.log("model not yet loaded: " + options.localModelFile);
111+
console.log("model not yet loaded: " + localModelFile);
61112

62113
const firModelLocalSourceBuilder = new com.google.firebase.ml.custom.model.FirebaseLocalModelSource.Builder(localModelName);
63114

64-
if (options.localModelFile.indexOf("~/") === 0) {
65-
firModelLocalSourceBuilder.setFilePath(fs.knownFolders.currentApp().path + options.localModelFile.substring(1));
115+
if (localModelFile.indexOf("~/") === 0) {
116+
firModelLocalSourceBuilder.setFilePath(fs.knownFolders.currentApp().path + localModelFile.substring(1));
66117
} else {
67118
// note that this doesn't seem to work, let's advice users to use ~/ for now
68-
firModelLocalSourceBuilder.setAssetFilePath(options.localModelFile);
119+
firModelLocalSourceBuilder.setAssetFilePath(localModelFile);
69120
}
70121

71122
localModelRegistrationSuccess = com.google.firebase.ml.custom.FirebaseModelManager.getInstance().registerLocalModelSource(firModelLocalSourceBuilder.build());
@@ -91,7 +142,7 @@ function getInterpreter(options: MLKitCustomModelOptions): any {
91142
export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitCustomModelResult> {
92143
return new Promise((resolve, reject) => {
93144
try {
94-
const interpreter = getInterpreter(options);
145+
const interpreter = getInterpreter(options.localModelFile);
95146

96147
let labels: Array<string>;
97148
if (options.labelsFile.indexOf("~/") === 0) {
@@ -130,7 +181,8 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
130181
intArrayIn[2] = options.modelInput[0].shape[2];
131182
intArrayIn[3] = options.modelInput[0].shape[3];
132183

133-
const inputType = options.modelInput[0].type === "FLOAT32" ? com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32 : com.google.firebase.ml.custom.FirebaseModelDataType.BYTE;
184+
const isQuantized = options.modelInput[0].type !== "FLOAT32";
185+
const inputType = isQuantized ? com.google.firebase.ml.custom.FirebaseModelDataType.BYTE : com.google.firebase.ml.custom.FirebaseModelDataType.FLOAT32;
134186

135187
let intArrayOut = Array.create("int", 2);
136188
intArrayOut[0] = 1;
@@ -142,9 +194,7 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
142194
.build();
143195

144196
const image: android.graphics.Bitmap = options.image instanceof ImageSource ? options.image.android : options.image.imageSource.android;
145-
146-
const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.bitmapToByteBuffer(image, options.modelInput[0].shape[1], options.modelInput[0].shape[2]);
147-
197+
const input = org.nativescript.plugins.firebase.mlkit.BitmapUtil.bitmapToByteBuffer(image, options.modelInput[0].shape[1], options.modelInput[0].shape[2], isQuantized);
148198
const inputs = new com.google.firebase.ml.custom.FirebaseModelInputs.Builder()
149199
.add(input) // add as many input arrays as your model requires
150200
.build();
@@ -161,16 +211,11 @@ export function useCustomModel(options: MLKitCustomModelOptions): Promise<MLKitC
161211
});
162212
}
163213

164-
function getImage(options: MLKitOptions): any /* com.google.firebase.ml.vision.common.FirebaseVisionImage */ {
165-
const image: android.graphics.Bitmap = options.image instanceof ImageSource ? options.image.android : options.image.imageSource.android;
166-
return com.google.firebase.ml.vision.common.FirebaseVisionImage.fromBitmap(image);
167-
}
168-
169-
function getSortedResult(labels: Array<string>, probabilities: Array<number>, maxResults?: number): Array<MLKitCustomModelResultValue> {
214+
function getSortedResult(labels: Array<string>, probabilities: Array<number>, maxResults = 5): Array<MLKitCustomModelResultValue> {
170215
const result: Array<MLKitCustomModelResultValue> = [];
171216
labels.forEach((text, i) => result.push({text, confidence: probabilities[i]}));
172217
result.sort((a, b) => a.confidence < b.confidence ? 1 : (a.confidence === b.confidence ? 0 : -1));
173-
if (maxResults && result.length > maxResults) {
218+
if (result.length > maxResults) {
174219
result.splice(maxResults);
175220
}
176221
result.map(r => r.confidence = (r.confidence & 0xff) / 255.0);

0 commit comments

Comments
 (0)