@@ -7,6 +7,7 @@ import 'package:flutter/material.dart';
77import 'package:flutter/services.dart' show rootBundle;
88import 'package:grpc/grpc.dart' ;
99import 'package:http/http.dart' as http;
10+ import 'package:protobuf/protobuf.dart' ;
1011
1112import 'proto/generated/tensorflow/core/framework/tensor.pb.dart' ;
1213import 'proto/generated/tensorflow/core/framework/tensor_shape.pb.dart' ;
@@ -64,88 +65,92 @@ class _TFServingDemoState extends State<TFServingDemo> {
6465 colorScheme: ColorScheme .fromSeed (seedColor: Colors .deepPurple),
6566 ),
6667 home: Scaffold (
67- appBar: AppBar (
68- title: const Text ('TF Serving Flutter Demo' ),
69- ),
68+ appBar: AppBar (title: const Text ('TF Serving Flutter Demo' )),
7069 body: Center (
7170 child: Container (
7271 padding: const EdgeInsets .fromLTRB (20 , 30 , 20 , 20 ),
7372 child: Column (
74- mainAxisAlignment: MainAxisAlignment .spaceEvenly,
75- children: [
76- TextField (
77- controller: _inputSentenceController,
78- decoration: const InputDecoration (
79- border: UnderlineInputBorder (),
80- hintText: 'Enter a sentence here' ) ,
73+ mainAxisAlignment: MainAxisAlignment .spaceEvenly,
74+ children: [
75+ TextField (
76+ controller: _inputSentenceController,
77+ decoration: const InputDecoration (
78+ border: UnderlineInputBorder (),
79+ hintText: 'Enter a sentence here' ,
8180 ),
82- Column (
83- children : < Widget > [
84- ListTile (
85- title : const Text ( 'gRPC' ),
86- leading : Radio < ConnectionModeType >(
87- value : ConnectionModeType .grpc,
88- groupValue : _connectionMode ,
89- onChanged : (value) {
90- setState (( ) {
91- _connectionMode = value;
92- }) ;
93- },
94- ) ,
81+ ),
82+ Column (
83+ children : < Widget > [
84+ ListTile (
85+ title : const Text ( 'gRPC' ),
86+ leading : Radio < ConnectionModeType >(
87+ value : ConnectionModeType .grpc ,
88+ groupValue : _connectionMode,
89+ onChanged : (value ) {
90+ setState (() {
91+ _connectionMode = value ;
92+ });
93+ } ,
9594 ),
96- ListTile (
97- title : const Text ( 'REST' ),
98- leading : Radio < ConnectionModeType >(
99- value : ConnectionModeType .rest,
100- groupValue : _connectionMode ,
101- onChanged : (value) {
102- setState (( ) {
103- _connectionMode = value;
104- }) ;
105- },
106- ) ,
95+ ),
96+ ListTile (
97+ title : const Text ( 'REST' ),
98+ leading : Radio < ConnectionModeType >(
99+ value : ConnectionModeType .rest ,
100+ groupValue : _connectionMode,
101+ onChanged : (value ) {
102+ setState (() {
103+ _connectionMode = value ;
104+ });
105+ } ,
107106 ),
108- ],
109- ),
110- Row (
111- mainAxisAlignment: MainAxisAlignment .spaceEvenly,
112- children: [
113- FilledButton (
114- style: FilledButton .styleFrom (
115- textStyle: const TextStyle (fontSize: 18 ),
116- ),
117- onPressed: () {
118- setState (() {
119- _futurePrediction = predict ();
120- });
121- },
122- child: const Text ('Classify' )),
123- FilledButton (
124- style: FilledButton .styleFrom (
125- textStyle: const TextStyle (fontSize: 18 ),
126- ),
127- onPressed: () {
128- setState (() {
129- _futurePrediction =
130- Future <String >.value (initialPrompt);
131- _inputSentenceController.clear ();
132- });
133- },
134- child: const Text ('Reset' ))
135- ]),
136- FutureBuilder <String >(
137- future: _futurePrediction,
138- builder: (context, snapshot) {
139- if (snapshot.hasData) {
140- return Text (snapshot.data! );
141- } else if (snapshot.hasError) {
142- return Text ('${snapshot .error }' );
143- }
144- // By default, show a loading spinner.
145- return const CircularProgressIndicator ();
146- },
147- ),
148- ]),
107+ ),
108+ ],
109+ ),
110+ Row (
111+ mainAxisAlignment: MainAxisAlignment .spaceEvenly,
112+ children: [
113+ FilledButton (
114+ style: FilledButton .styleFrom (
115+ textStyle: const TextStyle (fontSize: 18 ),
116+ ),
117+ onPressed: () {
118+ setState (() {
119+ _futurePrediction = predict ();
120+ });
121+ },
122+ child: const Text ('Classify' ),
123+ ),
124+ FilledButton (
125+ style: FilledButton .styleFrom (
126+ textStyle: const TextStyle (fontSize: 18 ),
127+ ),
128+ onPressed: () {
129+ setState (() {
130+ _futurePrediction = Future <String >.value (
131+ initialPrompt,
132+ );
133+ _inputSentenceController.clear ();
134+ });
135+ },
136+ child: const Text ('Reset' ),
137+ ),
138+ ],
139+ ),
140+ FutureBuilder <String >(
141+ future: _futurePrediction,
142+ builder: (context, snapshot) {
143+ if (snapshot.hasData) {
144+ return Text (snapshot.data! );
145+ } else if (snapshot.hasError) {
146+ return Text ('${snapshot .error }' );
147+ }
148+ // By default, show a loading spinner.
149+ return const CircularProgressIndicator ();
150+ },
151+ ),
152+ ],
153+ ),
149154 ),
150155 ),
151156 ),
@@ -212,34 +217,42 @@ class _TFServingDemoState extends State<TFServingDemo> {
212217 throw Exception ('Error response' );
213218 }
214219 } else {
215- final channel = ClientChannel (_server,
216- port: grpcPort,
217- options:
218- const ChannelOptions (credentials: ChannelCredentials .insecure ()));
219- _stub = PredictionServiceClient (channel,
220- options: CallOptions (timeout: const Duration (seconds: 10 )));
220+ final channel = ClientChannel (
221+ _server,
222+ port: grpcPort,
223+ options: const ChannelOptions (
224+ credentials: ChannelCredentials .insecure (),
225+ ),
226+ );
227+ _stub = PredictionServiceClient (
228+ channel,
229+ options: CallOptions (timeout: const Duration (seconds: 10 )),
230+ );
221231
222232 ModelSpec modelSpec = ModelSpec (
223233 name: 'spam-detection' ,
224234 signatureName: 'serving_default' ,
225235 );
226236
227237 TensorShapeProto_Dim batchDim = TensorShapeProto_Dim (size: Int64 (1 ));
228- TensorShapeProto_Dim inputDim =
229- TensorShapeProto_Dim (size: Int64 (maxSentenceLength));
230- TensorShapeProto inputTensorShape =
231- TensorShapeProto (dim: [batchDim, inputDim]);
238+ TensorShapeProto_Dim inputDim = TensorShapeProto_Dim (
239+ size: Int64 (maxSentenceLength),
240+ );
241+ TensorShapeProto inputTensorShape = TensorShapeProto (
242+ dim: [batchDim, inputDim],
243+ );
232244 TensorProto inputTensor = TensorProto (
233- dtype: DataType .DT_INT32 ,
234- tensorShape: inputTensorShape,
235- intVal: _tokenIndices);
245+ dtype: DataType .DT_INT32 ,
246+ tensorShape: inputTensorShape,
247+ intVal: _tokenIndices,
248+ );
236249
237250 // If you train your own model, make sure to update the input and output
238251 // tensor names.
239252 const inputTensorName = 'input_3' ;
240253 const outputTensorName = 'dense_5' ;
241- PredictRequest request = PredictRequest (
242- modelSpec : modelSpec, inputs: {inputTensorName: inputTensor});
254+ PredictRequest request = PredictRequest (modelSpec : modelSpec)
255+ .. inputs. addAll ( {inputTensorName: inputTensor});
243256
244257 PredictResponse response = await _stub.predict (request);
245258
0 commit comments