-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfaceEmbedding.ts
More file actions
150 lines (126 loc) · 4.25 KB
/
faceEmbedding.ts
File metadata and controls
150 lines (126 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import ImageEditor from '@react-native-community/image-editor';
import RNFS from 'react-native-fs';
import { Skia } from '@shopify/react-native-skia';
import type { TensorflowModel } from 'react-native-fast-tflite';
import { Buffer } from 'buffer';
/**
 * Face bounding box in source-image pixel coordinates
 * (e.g. the frame reported by an MLKit face detector — see main API below).
 */
export type FaceFrame = {
top: number; // y offset of the box's top edge
left: number; // x offset of the box's left edge
width: number; // box width in pixels
height: number; // box height in pixels
};
/**
 * Crop + resize to 112x112 from a LOCAL file:// image URI.
 * Uses @react-native-community/image-editor (replacement for deprecated core ImageEditor).
 *
 * @param localFileUri - source image URI; must start with `file://`.
 * @param frame - face bounding box in source-image pixels (may be fractional).
 * @returns `file://` URI of the 112x112 cropped image.
 * @throws Error if `localFileUri` is not a local file:// URI.
 */
async function cropFace112(
localFileUri: string,
frame: FaceFrame,
): Promise<string> {
if (!localFileUri.startsWith('file://')) {
throw new Error(
`imageUri must be a local file:// URI. Got: ${localFileUri}`,
);
}
// ✅ FIX: MLKit bounding boxes are often fractional floats; the native
// crop implementations expect integer pixel values, so round explicitly
// in addition to clamping negative offsets (faces partially off-frame).
const cropData = {
offset: {
x: Math.round(Math.max(0, frame.left)),
y: Math.round(Math.max(0, frame.top)),
},
size: {
width: Math.round(Math.max(1, frame.width)),
height: Math.round(Math.max(1, frame.height)),
},
displaySize: { width: 112, height: 112 },
resizeMode: 'cover' as const,
};
// Modern ImageEditor returns an object { uri: string, width: number, ... };
// older versions returned a bare string — handle both shapes.
const result = await ImageEditor.cropImage(localFileUri, cropData);
const croppedUri = typeof result === 'string' ? result : result.uri;
// Ensure the URI has the file:// prefix for RNFS/Skia
if (!croppedUri.startsWith('file://') && !croppedUri.startsWith('http')) {
return `file://${croppedUri}`;
}
return croppedUri;
}
/**
 * Decode a 112x112 cropped image into a Float32 input tensor [1,112,112,3].
 * Reads the file via RNFS (so a file:// URI is required), decodes with Skia,
 * then normalizes each RGB channel as (x - 127.5) / 128.0 (alpha is dropped).
 */
async function image112ToInputTensor(
croppedUri: string,
): Promise<Float32Array> {
if (!croppedUri || typeof croppedUri !== 'string') {
throw new Error(
'image112ToInputTensor: croppedUri is undefined or not a string',
);
}
const filePath = croppedUri.replace('file://', '');
if (!(await RNFS.exists(filePath))) {
throw new Error(`File does not exist at path: ${filePath}`);
}
const encoded = await RNFS.readFile(filePath, 'base64');
const rawBytes = new Uint8Array(Buffer.from(encoded, 'base64'));
const decoded = Skia.Image.MakeImageFromEncoded(Skia.Data.fromBytes(rawBytes));
if (!decoded) {
throw new Error('Failed to decode cropped face image into Skia.');
}
// Pass explicit (0,0) origin plus a full ImageInfo object — some Skia
// versions require all of these fields for readPixels to succeed.
const rgba = decoded.readPixels(0, 0, {
width: 112,
height: 112,
colorType: 4, // RGBA_8888
alphaType: 1, // Opaque
});
if (!rgba) {
throw new Error(
'Failed to read pixels. Check if image dimensions are exactly 112x112.',
);
}
// rgba.length should be 112 * 112 * 4 = 50176; output holds 112*112*3 floats.
const tensor = new Float32Array(1 * 112 * 112 * 3);
let out = 0;
// Walk RGBA quads, emitting normalized RGB; stop if the output is full.
for (let px = 0; px < rgba.length && out < tensor.length; px += 4) {
tensor[out++] = (rgba[px + 0] - 127.5) / 128.0;
tensor[out++] = (rgba[px + 1] - 127.5) / 128.0;
tensor[out++] = (rgba[px + 2] - 127.5) / 128.0;
}
return tensor;
}
/**
 * L2-normalize a vector: divide every component by its Euclidean norm.
 * A zero vector is returned unchanged (the norm falls back to 1).
 */
function l2Normalize(vec: Float32Array): Float32Array {
const sumOfSquares = vec.reduce((acc, v) => acc + v * v, 0);
const norm = Math.sqrt(sumOfSquares) || 1;
return vec.map(v => v / norm);
}
/**
 * Main API: compute an L2-normalized face embedding.
 *
 * Pipeline: local file:// image + MLKit bbox frame
 *   -> 112x112 crop -> normalized input tensor
 *   -> MobileFaceNet TFLite inference (react-native-fast-tflite)
 *   -> L2-normalized embedding (Float32Array).
 */
export async function getMobileFaceNetEmbeddingFromFrame(
localFileUri: string,
frame: FaceFrame,
model: TensorflowModel,
): Promise<Float32Array> {
const croppedUri = await cropFace112(localFileUri, frame);
const tensor = await image112ToInputTensor(croppedUri);
const [rawEmbedding] = model.runSync([tensor]);
return l2Normalize(rawEmbedding as Float32Array);
}
/**
 * Optional helper: returns the file:// URI of the 112x112 cropped face
 * (useful for saving the crop or debugging the pipeline).
 */
export async function cropFace112ForDebug(
localFileUri: string,
frame: FaceFrame,
): Promise<string> {
const croppedUri = await cropFace112(localFileUri, frame);
return croppedUri;
}