@@ -13,14 +13,14 @@ import UIKit
1313import os. log
1414
1515public enum MobileNetClassifierError : Error {
16- case inputPointer
16+ case inference
1717 case rawData
1818 case transform
1919
2020 var localizedDescription : String {
2121 switch self {
22- case . inputPointer :
23- return " Cannot get the input pointer base address "
22+ case . inference :
23+ return " Cannot recognize the image "
2424 case . rawData:
2525 return " Cannot get the pixel data from the image "
2626 case . transform:
@@ -35,15 +35,15 @@ public class MobileNetClassifier: ImageClassification {
3535 private static let resizeSize : CGFloat = 256
3636 private static let cropSize : CGFloat = 224
3737
38- private var mobileNetClassifier : ETMobileNetClassifier
38+ private var module : Module
3939 private var labels : [ String ] = [ ]
4040 private var rawDataBuffer : [ UInt8 ]
4141 private var normalizedBuffer : [ Float ]
4242
4343 public init ? ( modelFilePath: String , labelsFilePath: String ) throws {
4444 labels = try String ( contentsOfFile: labelsFilePath, encoding: . utf8)
4545 . components ( separatedBy: . newlines)
46- mobileNetClassifier = ETMobileNetClassifier ( filePath: modelFilePath)
46+ module = Module ( filePath: modelFilePath)
4747 rawDataBuffer = [ UInt8] ( repeating: 0 , count: Int ( Self . cropSize * Self. cropSize) * 4 )
4848 normalizedBuffer = [ Float] ( repeating: 0 , count: rawDataBuffer. count / 4 * 3 )
4949
@@ -59,22 +59,18 @@ public class MobileNetClassifier: ImageClassification {
5959 }
6060
6161 public func classify( image: UIImage ) throws -> [ Classification ] {
62- var input = try normalize ( rawData ( from: transformed ( image) ) )
63- var output = [ Float] ( repeating: 0 , count: labels. count)
64-
65- try input. withUnsafeMutableBufferPointer { inputPointer in
66- guard let inputPointerBaseAddress = inputPointer. baseAddress else {
67- throw MobileNetClassifierError . inputPointer
68- }
69- try mobileNetClassifier. classify (
70- withInput: inputPointerBaseAddress,
71- output: & output,
72- outputSize: labels. count)
62+ let input = try normalize ( rawData ( from: transformed ( image) ) ) . withUnsafeBytes {
63+ Tensor ( bytes: $0. baseAddress!, shape: [ 1 , 3 , 224 , 224 ] , dataType: . float)
64+ }
65+ guard let output = try module. forward ( input) . first? . tensor? . withUnsafeBytes ( [ Float ] . init)
66+ else {
67+ throw MobileNetClassifierError . inference
7368 }
74- return softmax ( output) . enumerated ( ) . sorted ( by: { $0. element > $1. element } )
75- . compactMap { ( index, probability) -> Classification ? in
76- guard index < labels. count else { return nil }
77- return Classification ( label: labels [ index] , confidence: probability)
69+ return softmax ( output)
70+ . enumerated ( )
71+ . sorted ( by: { $0. element > $1. element } )
72+ . compactMap { index, probability in
73+ index < labels. count ? Classification ( label: labels [ index] , confidence: probability) : nil
7874 }
7975 }
8076
0 commit comments