Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions src/ImageClassifier/darknet.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) 2018 ml5
//
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

import * as tf from '@tensorflow/tfjs';
import IMAGENET_CLASSES_DARKNET from '../utils/IMAGENET_CLASSES_DARKNET';

const DEFAULTS = {
DARKNET_URL: 'https://rawgit.com/TheHidden1/ml5-data-and-models/darknetclassifier/models/darknetclassifier/darknetreference/model.json',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a reminder that we need to change this to the ones hosted in the ml5.js models

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i made a pull request on ml5js/ml5-data-and-models#32
can you merge it ?

DARKNET_TINY_URL: 'https://rawgit.com/TheHidden1/ml5-data-and-models/darknetclassifier/models/darknetclassifier/darknettiny/model.json',
IMAGE_SIZE_DARKNET: 256,
IMAGE_SIZE_DARKNET_TINY: 224,
};


async function getTopKClasses(logits, topK) {
const values = await logits.data();
const valuesAndIndices = [];
for (let i = 0; i < values.length; i += 1) {
valuesAndIndices.push({
value: values[i],
index: i,
});
}
valuesAndIndices.sort((a, b) => b.value - a.value);

const topkValues = new Float32Array(topK);
const topkIndices = new Int32Array(topK);
for (let i = 0; i < topK; i += 1) {
topkValues[i] = valuesAndIndices[i].value;
topkIndices[i] = valuesAndIndices[i].index;
}

const topClassesAndProbs = [];
for (let i = 0; i < topkIndices.length; i += 1) {
topClassesAndProbs.push({
className: IMAGENET_CLASSES_DARKNET[topkIndices[i]],
probability: topkValues[i],
});
}
return topClassesAndProbs;
}

function preProcess(img, size) {
let image;
if (!(img instanceof tf.Tensor)) {
if (img instanceof HTMLImageElement || img instanceof HTMLVideoElement) {
image = tf.fromPixels(img);
} else if (typeof img === 'object' && (img.elt instanceof HTMLImageElement || img.elt instanceof HTMLVideoElement)) {
image = tf.fromPixels(img.elt); // Handle p5.js image and video.
}
} else {
image = img;
}
// Normalize the image from [0, 255] to [0, 1].
const normalized = image.toFloat().div(tf.scalar(255));
let resized = normalized;
if (normalized.shape[0] !== size || normalized.shape[1] !== size) {
const alignCorners = true;
resized = tf.image.resizeBilinear(normalized, [size, size], alignCorners);
}
// Reshape to a single-element batch so we can pass it to predict.
const batched = resized.reshape([1, size, size, 3]);
// Scale Stuff
// this.scaleX = this.imgHeight / this.inputHeight;
// this.scaleY = this.imgWidth / this.inputWidth;
return batched;
}
export class Darknet {
constructor(version) {
this.version = version;
switch (this.version) {
case 'reference':
this.imgSize = DEFAULTS.IMAGE_SIZE_DARKNET;
break;
case 'tiny':
this.imgSize = DEFAULTS.IMAGE_SIZE_DARKNET_TINY;
break;

default:
break;
}
}

async load() {
switch (this.version) {
// might add darknet_448
case 'reference':
this.model = await tf.loadModel(DEFAULTS.DARKNET_URL);
break;

case 'tiny':
this.model = await tf.loadModel(DEFAULTS.DARKNET_TINY_URL);
break;

default:
break;
}

// Warmup the model.
const result = tf.tidy(() => this.model.predict(tf.zeros([1, this.imgSize, this.imgSize, 3])));
await result.data();
result.dispose();
}

/**
* Classifies an image from the 1000 ImageNet classes returning a map of
* the most likely class names to their probability.
*
* @param img The image to classify. Can be a tensor or a DOM element image,
* video, or canvas.
* @param topk How many top values to use. Defaults to 3.
*/
async classify(img, topk = 3) {
const logits = tf.tidy(() => {
const imgData = preProcess(img, this.imgSize);
const predictions = this.model.predict(imgData);
return tf.softmax(predictions);
});
const classes = await getTopKClasses(logits, topk);
logits.dispose();
return classes;
}
}


export async function load(version) {
if (tf == null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we need this validation here. TF.js comes with ml5.js

throw new Error('Cannot find TensorFlow.js. If you are using a <script> tag, please ' +
'also include @tensorflow/tfjs on the page before using this model.');
}
if (version !== 'reference' && version !== 'tiny') {
throw new Error('Please select a version : darknet-reference or darknet-tiny');
}

const darknet = new Darknet(version);
await darknet.load();
return darknet;
}
26 changes: 19 additions & 7 deletions src/ImageClassifier/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ Image Classifier using pre-trained networks

import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';
import * as darknet from './darknet';
import callCallback from '../utils/callcallback';


const DEFAULTS = {
mobilenet: {
version: 1,
Expand All @@ -23,14 +25,24 @@ class ImageClassifier {
constructor(modelName, video, options, callback) {
this.modelName = modelName;
this.video = video;
this.version = options.version || DEFAULTS[this.modelName].version;
this.alpha = options.alpha || DEFAULTS[this.modelName].alpha;
this.topk = options.topk || DEFAULTS[this.modelName].topk;
this.model = null;
if (this.modelName === 'mobilenet') {
this.modelToUse = mobilenet;
} else {
this.modelToUse = null;
switch (this.modelName) {
case 'mobilenet':
this.modelToUse = mobilenet;
this.version = options.version || DEFAULTS.mobilenet.version;
this.alpha = options.alpha || DEFAULTS.mobilenet.alpha;
this.topk = options.topk || DEFAULTS.mobilenet.topk;
break;
case 'darknet':
this.version = 'reference'; // this a 28mb model
this.modelToUse = darknet;
break;
case 'darknet-tiny':
this.version = 'tiny'; // this a 4mb model
this.modelToUse = darknet;
break;
default:
this.modelToUse = null;
}
// Load the model
this.ready = callCallback(this.loadModel(), callback);
Expand Down
Loading