@@ -498,88 +498,92 @@ steps:
498498 colorScheme: ColorScheme.fromSeed(seedColor: Colors.deepPurple),
499499 ),
500500 home: Scaffold(
501- appBar: AppBar(
502- title: const Text('TF Serving Flutter Demo'),
503- ),
501+ appBar: AppBar(title: const Text('TF Serving Flutter Demo')),
504502 body: Center(
505503 child: Container(
506504 padding: const EdgeInsets.fromLTRB(20, 30, 20, 20),
507505 child: Column(
508- mainAxisAlignment: MainAxisAlignment.spaceEvenly,
509- children: [
510- TextField(
511- controller: _inputSentenceController,
512- decoration: const InputDecoration(
513- border: UnderlineInputBorder(),
514- hintText: 'Enter a sentence here') ,
506+ mainAxisAlignment: MainAxisAlignment.spaceEvenly,
507+ children: [
508+ TextField(
509+ controller: _inputSentenceController,
510+ decoration: const InputDecoration(
511+ border: UnderlineInputBorder(),
512+ hintText: 'Enter a sentence here',
515513 ),
516- Column(
517- children: <Widget>[
518- ListTile(
519- title: const Text('gRPC'),
520- leading: Radio<ConnectionModeType>(
521- value: ConnectionModeType.grpc,
522- groupValue: _connectionMode ,
523- onChanged: (value) {
524- setState(( ) {
525- _connectionMode = value;
526- }) ;
527- },
528- ) ,
514+ ),
515+ Column(
516+ children: <Widget>[
517+ ListTile(
518+ title: const Text('gRPC'),
519+ leading: Radio< ConnectionModeType>(
520+ value: ConnectionModeType.grpc ,
521+ groupValue: _connectionMode,
522+ onChanged: (value ) {
523+ setState(() {
524+ _connectionMode = value ;
525+ });
526+ } ,
529527 ),
530- ListTile(
531- title: const Text('REST'),
532- leading: Radio<ConnectionModeType>(
533- value: ConnectionModeType.rest,
534- groupValue: _connectionMode ,
535- onChanged: (value) {
536- setState(( ) {
537- _connectionMode = value;
538- }) ;
539- },
540- ) ,
528+ ),
529+ ListTile(
530+ title: const Text('REST'),
531+ leading: Radio< ConnectionModeType>(
532+ value: ConnectionModeType.rest ,
533+ groupValue: _connectionMode,
534+ onChanged: (value ) {
535+ setState(() {
536+ _connectionMode = value ;
537+ });
538+ } ,
541539 ),
542- ],
543- ),
544- Row(
545- mainAxisAlignment: MainAxisAlignment.spaceEvenly,
546- children: [
547- FilledButton(
548- style: FilledButton.styleFrom(
549- textStyle: const TextStyle(fontSize: 18),
550- ),
551- onPressed: () {
552- setState(() {
553- _futurePrediction = predict();
554- });
555- },
556- child: const Text('Classify')),
557- FilledButton(
558- style: FilledButton.styleFrom(
559- textStyle: const TextStyle(fontSize: 18),
560- ),
561- onPressed: () {
562- setState(() {
563- _futurePrediction =
564- Future<String>.value(initialPrompt);
565- _inputSentenceController.clear();
566- });
567- },
568- child: const Text('Reset'))
569- ]),
570- FutureBuilder<String>(
571- future: _futurePrediction,
572- builder: (context, snapshot) {
573- if (snapshot.hasData) {
574- return Text(snapshot.data!);
575- } else if (snapshot.hasError) {
576- return Text('${snapshot.error}');
577- }
578- // By default, show a loading spinner.
579- return const CircularProgressIndicator();
580- },
581- ),
582- ]),
540+ ),
541+ ],
542+ ),
543+ Row(
544+ mainAxisAlignment: MainAxisAlignment.spaceEvenly,
545+ children: [
546+ FilledButton(
547+ style: FilledButton.styleFrom(
548+ textStyle: const TextStyle(fontSize: 18),
549+ ),
550+ onPressed: () {
551+ setState(() {
552+ _futurePrediction = predict();
553+ });
554+ },
555+ child: const Text('Classify'),
556+ ),
557+ FilledButton(
558+ style: FilledButton.styleFrom(
559+ textStyle: const TextStyle(fontSize: 18),
560+ ),
561+ onPressed: () {
562+ setState(() {
563+ _futurePrediction = Future<String>.value(
564+ initialPrompt,
565+ );
566+ _inputSentenceController.clear();
567+ });
568+ },
569+ child: const Text('Reset'),
570+ ),
571+ ],
572+ ),
573+ FutureBuilder<String>(
574+ future: _futurePrediction,
575+ builder: (context, snapshot) {
576+ if (snapshot.hasData) {
577+ return Text(snapshot.data!);
578+ } else if (snapshot.hasError) {
579+ return Text('${snapshot.error}');
580+ }
581+ // By default, show a loading spinner.
582+ return const CircularProgressIndicator();
583+ },
584+ ),
585+ ],
586+ ),
583587 ),
584588 ),
585589 ),
@@ -4076,12 +4080,20 @@ steps:
40764080 patch-u : |
40774081 --- b/tfserving-flutter/codelab2/finished/lib/main.dart
40784082 +++ a/tfserving-flutter/codelab2/finished/lib/main.dart
4079- @@ -160,21 +160,99 @@ class _TFServingDemoState extends State<TFServingDemo> {
4083+ @@ -7,6 +7,7 @@ import 'package:flutter/material.dart';
4084+ import 'package:flutter/services.dart' show rootBundle;
4085+ import 'package:grpc/grpc.dart';
4086+ import 'package:http/http.dart' as http;
4087+ +import 'package:protobuf/protobuf.dart';
4088+
4089+ import 'proto/generated/tensorflow/core/framework/tensor.pb.dart';
4090+ import 'proto/generated/tensorflow/core/framework/tensor_shape.pb.dart';
4091+ @@ -164,21 +165,107 @@ class _TFServingDemoState extends State<TFServingDemo> {
40804092 // For iOS emulator, desktop and web platforms
40814093 _server = '127.0.0.1';
40824094 }
40834095 - // TODO: build _vocabMap if empty
4084-
4096+
40854097 - // TODO: tokenize the input sentence.
40864098 + if (_vocabMap.isEmpty) {
40874099 + final vocabFileString = await rootBundle.loadString(vocabFile);
@@ -4113,7 +4125,7 @@ steps:
41134125 + break;
41144126 + }
41154127 + }
4116-
4128+
41174129 if (_connectionMode == ConnectionModeType.rest) {
41184130 - // TODO: create and send the REST request
41194131 + final response = await http.post(
@@ -4122,7 +4134,7 @@ steps:
41224134 + 'instances': [_tokenIndices],
41234135 + }),
41244136 + );
4125-
4137+
41264138 - // TODO: process the REST response
41274139 + if (response.statusCode == 200) {
41284140 + Map<String, dynamic> result =
@@ -4137,38 +4149,46 @@ steps:
41374149 + }
41384150 } else {
41394151 - // TODO: create the gRPC request
4140- + final channel = ClientChannel(_server,
4141- + port: grpcPort,
4142- + options:
4143- + const ChannelOptions(credentials: ChannelCredentials.insecure()));
4144- + _stub = PredictionServiceClient(channel,
4145- + options: CallOptions(timeout: const Duration(seconds: 10)));
4152+ + final channel = ClientChannel(
4153+ + _server,
4154+ + port: grpcPort,
4155+ + options: const ChannelOptions(
4156+ + credentials: ChannelCredentials.insecure(),
4157+ + ),
4158+ + );
4159+ + _stub = PredictionServiceClient(
4160+ + channel,
4161+ + options: CallOptions(timeout: const Duration(seconds: 10)),
4162+ + );
41464163 +
41474164 + ModelSpec modelSpec = ModelSpec(
41484165 + name: 'spam-detection',
41494166 + signatureName: 'serving_default',
41504167 + );
41514168 +
41524169 + TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
4153- + TensorShapeProto_Dim inputDim =
4154- + TensorShapeProto_Dim(size: Int64(maxSentenceLength));
4155- + TensorShapeProto inputTensorShape =
4156- + TensorShapeProto(dim: [batchDim, inputDim]);
4170+ + TensorShapeProto_Dim inputDim = TensorShapeProto_Dim(
4171+ + size: Int64(maxSentenceLength),
4172+ + );
4173+ + TensorShapeProto inputTensorShape = TensorShapeProto(
4174+ + dim: [batchDim, inputDim],
4175+ + );
41574176 + TensorProto inputTensor = TensorProto(
4158- + dtype: DataType.DT_INT32,
4159- + tensorShape: inputTensorShape,
4160- + intVal: _tokenIndices);
4177+ + dtype: DataType.DT_INT32,
4178+ + tensorShape: inputTensorShape,
4179+ + intVal: _tokenIndices,
4180+ + );
41614181 +
41624182 + // If you train your own model, make sure to update the input and output
41634183 + // tensor names.
41644184 + const inputTensorName = 'input_3';
41654185 + const outputTensorName = 'dense_5';
4166- + PredictRequest request = PredictRequest(
4167- + modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});
4168-
4186+ + PredictRequest request = PredictRequest(modelSpec: modelSpec)
4187+ + .. inputs.addAll( {inputTensorName: inputTensor});
4188+
41694189 - // TODO: send the gRPC request
41704190 + PredictResponse response = await _stub.predict(request);
4171-
4191+
41724192 - // TODO: process the gRPC response
41734193 + if (response.outputs.containsKey(outputTensorName)) {
41744194 + if (response.outputs[outputTensorName]!.floatVal[1] >
0 commit comments