1+ package mariannelinhares .mnistandroid ;
2+
3+ /*
4+ Copyright 2016 Narrative Nights Inc. All Rights Reserved.
5+
6+ Licensed under the Apache License, Version 2.0 (the "License");
7+ you may not use this file except in compliance with the License.
8+ You may obtain a copy of the License at
9+
10+ http://www.apache.org/licenses/LICENSE-2.0
11+
12+ Unless required by applicable law or agreed to in writing, software
13+ distributed under the License is distributed on an "AS IS" BASIS,
14+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+ See the License for the specific language governing permissions and
16+ limitations under the License.
17+
18+ From: https://raw.githubusercontent
19+ .com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist
20+ /DrawModel.java
21+ */
22+
23+ //An activity is a single, focused thing that the user can do. Almost all activities interact with the user,
24+ //so the Activity class takes care of creating a window for you in which you can place your UI with setContentView(View)
25+ import android .app .Activity ;
26+ //PointF holds two float coordinates
27+ import android .graphics .PointF ;
28+ //A mapping from String keys to various Parcelable values (interface for data container values, parcels)
29+ import android .os .Bundle ;
30+ //Object used to report movement (mouse, pen, finger, trackball) events.
31+ // //Motion events may hold either absolute or relative movements and other data, depending on the type of device.
32+ import android .view .MotionEvent ;
33+ //This class represents the basic building block for user interface components.
34+ // A View occupies a rectangular area on the screen and is responsible for drawing
35+ import android .view .View ;
36+ //A user interface element the user can tap or click to perform an action.
37+ import android .widget .Button ;
38+ //A user interface element that displays text to the user. To provide user-editable text, see EditText.
39+ import android .widget .TextView ;
40+ //Resizable-array implementation of the List interface. Implements all optional list operations, and permits all elements,
41+ // including null. In addition to implementing the List interface, this class provides methods to
42+ // //manipulate the size of the array that is used internally to store the list.
43+ import java .util .ArrayList ;
44+ // basic list
45+ import java .util .List ;
46+ //encapsulates a classified image
47+ //public interface to the classification class, exposing a name and the recognize function
48+ import mariannelinhares .mnistandroid .models .Classifier ;
49+ //contains logic for reading labels, creating classifier, and classifying
50+ import mariannelinhares .mnistandroid .models .TensorFlowClassifier ;
51+ //class for drawing MNIST digits by finger
52+ import mariannelinhares .mnistandroid .views .DrawModel ;
53+ //class for drawing the entire app
54+ import mariannelinhares .mnistandroid .views .DrawView ;
55+
56+ public class MainActivity extends Activity implements View .OnClickListener , View .OnTouchListener {
57+
58+ private static final int PIXEL_WIDTH = 28 ;
59+
60+ // ui elements
61+ private Button clearBtn , classBtn ;
62+ private TextView resText ;
63+ private List <Classifier > mClassifiers = new ArrayList <>();
64+
65+ // views
66+ private DrawModel drawModel ;
67+ private DrawView drawView ;
68+ private PointF mTmpPiont = new PointF ();
69+
70+ private float mLastX ;
71+ private float mLastY ;
72+
73+ @ Override
74+ // In the onCreate() method, you perform basic application startup logic that should happen
75+ //only once for the entire life of the activity.
76+ protected void onCreate (Bundle savedInstanceState ) {
77+ //initialization
78+ super .onCreate (savedInstanceState );
79+ setContentView (R .layout .activity_main );
80+
81+ //get drawing view from XML (where the finger writes the number)
82+ drawView = (DrawView ) findViewById (R .id .draw );
83+ //get the model object
84+ drawModel = new DrawModel (PIXEL_WIDTH , PIXEL_WIDTH );
85+
86+ //init the view with the model object
87+ drawView .setModel (drawModel );
88+ // give it a touch listener to activate when the user taps
89+ drawView .setOnTouchListener (this );
90+
91+ //clear button
92+ //clear the drawing when the user taps
93+ clearBtn = (Button ) findViewById (R .id .btn_clear );
94+ clearBtn .setOnClickListener (this );
95+
96+ //class button
97+ //when tapped, this performs classification on the drawn image
98+ classBtn = (Button ) findViewById (R .id .btn_class );
99+ classBtn .setOnClickListener (this );
100+
101+ // res text
102+ //this is the text that shows the output of the classification
103+ resText = (TextView ) findViewById (R .id .tfRes );
104+
105+ // tensorflow
106+ //load up our saved model to perform inference from local storage
107+ loadModel ();
108+ }
109+
110+ //the activity lifecycle
111+
112+ @ Override
113+ //OnResume() is called when the user resumes his Activity which he left a while ago,
114+ // //say he presses home button and then comes back to app, onResume() is called.
115+ protected void onResume () {
116+ drawView .onResume ();
117+ super .onResume ();
118+ }
119+
120+ @ Override
121+ //OnPause() is called when the user receives an event like a call or a text message,
122+ // //when onPause() is called the Activity may be partially or completely hidden.
123+ protected void onPause () {
124+ drawView .onPause ();
125+ super .onPause ();
126+ }
127+ //creates a model object in memory using the saved tensorflow protobuf model file
128+ //which contains all the learned weights
129+ private void loadModel () {
130+ //The Runnable interface is another way in which you can implement multi-threading other than extending the
131+ // //Thread class due to the fact that Java allows you to extend only one class. Runnable is just an interface,
132+ // //which provides the method run.
133+ // //Threads are implementations and use Runnable to call the method run().
134+ new Thread (new Runnable () {
135+ @ Override
136+ public void run () {
137+ try {
138+ //add 2 classifiers to our classifier arraylist
139+ //the tensorflow classifier and the keras classifier
140+ mClassifiers .add (
141+ TensorFlowClassifier .create (getAssets (), "TensorFlow" ,
142+ "opt_mnist_convnet-tf.pb" , "labels.txt" , PIXEL_WIDTH ,
143+ "input" , "output" , true ));
144+ mClassifiers .add (
145+ TensorFlowClassifier .create (getAssets (), "Keras" ,
146+ "opt_mnist_convnet-keras.pb" , "labels.txt" , PIXEL_WIDTH ,
147+ "conv2d_1_input" , "dense_2/Softmax" , false ));
148+ } catch (final Exception e ) {
149+ //if they aren't found, throw an error!
150+ throw new RuntimeException ("Error initializing classifiers!" , e );
151+ }
152+ }
153+ }).start ();
154+ }
155+
156+ @ Override
157+ public void onClick (View view ) {
158+ //when the user clicks something
159+ if (view .getId () == R .id .btn_clear ) {
160+ //if its the clear button
161+ //clear the drawing
162+ drawModel .clear ();
163+ drawView .reset ();
164+ drawView .invalidate ();
165+ //empty the text view
166+ resText .setText ("" );
167+ } else if (view .getId () == R .id .btn_class ) {
168+ //if the user clicks the classify button
169+ //get the pixel data and store it in an array
170+ float pixels [] = drawView .getPixelData ();
171+
172+ //init an empty string to fill with the classification output
173+ String text = "" ;
174+ //for each classifier in our array
175+ for (Classifier classifier : mClassifiers ) {
176+ //perform classification on the image
177+ final Classification res = classifier .recognize (pixels );
178+ //if it can't classify, output a question mark
179+ if (res .getLabel () == null ) {
180+ text += classifier .name () + ": ?\n " ;
181+ } else {
182+ //else output its name
183+ text += String .format ("%s: %s, %f\n " , classifier .name (), res .getLabel (),
184+ res .getConf ());
185+ }
186+ }
187+ resText .setText (text );
188+ }
189+ }
190+
191+ @ Override
192+ //this method detects which direction a user is moving
193+ //their finger and draws a line accordingly in that
194+ //direction
195+ public boolean onTouch (View v , MotionEvent event ) {
196+ //get the action and store it as an int
197+ int action = event .getAction () & MotionEvent .ACTION_MASK ;
198+ //actions have predefined ints, lets match
199+ //to detect, if the user has touched, which direction the users finger is
200+ //moving, and if they've stopped moving
201+
202+ //if touched
203+ if (action == MotionEvent .ACTION_DOWN ) {
204+ //begin drawing line
205+ processTouchDown (event );
206+ return true ;
207+ //draw line in every direction the user moves
208+ } else if (action == MotionEvent .ACTION_MOVE ) {
209+ processTouchMove (event );
210+ return true ;
211+ //if finger is lifted, stop drawing
212+ } else if (action == MotionEvent .ACTION_UP ) {
213+ processTouchUp ();
214+ return true ;
215+ }
216+ return false ;
217+ }
218+
219+ //draw line down
220+
221+ private void processTouchDown (MotionEvent event ) {
222+ //calculate the x, y coordinates where the user has touched
223+ mLastX = event .getX ();
224+ mLastY = event .getY ();
225+ //user them to calcualte the position
226+ drawView .calcPos (mLastX , mLastY , mTmpPiont );
227+ //store them in memory to draw a line between the
228+ //difference in positions
229+ float lastConvX = mTmpPiont .x ;
230+ float lastConvY = mTmpPiont .y ;
231+ //and begin the line drawing
232+ drawModel .startLine (lastConvX , lastConvY );
233+ }
234+
235+ //the main drawing function
236+ //it actually stores all the drawing positions
237+ //into the drawmodel object
238+ //we actually render the drawing from that object
239+ //in the drawrenderer class
240+ private void processTouchMove (MotionEvent event ) {
241+ float x = event .getX ();
242+ float y = event .getY ();
243+
244+ drawView .calcPos (x , y , mTmpPiont );
245+ float newConvX = mTmpPiont .x ;
246+ float newConvY = mTmpPiont .y ;
247+ drawModel .addLineElem (newConvX , newConvY );
248+
249+ mLastX = x ;
250+ mLastY = y ;
251+ drawView .invalidate ();
252+ }
253+
254+ private void processTouchUp () {
255+ drawModel .endLine ();
256+ }
257+ }
0 commit comments