Skip to content

Commit 358653e

Browse files
committed
Update codelab_rebuild.yaml
1 parent 8523672 commit 358653e

File tree

1 file changed

+116
-96
lines changed

1 file changed

+116
-96
lines changed

tfserving-flutter/codelab_rebuild.yaml

Lines changed: 116 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -498,88 +498,92 @@ steps:
498498
colorScheme: ColorScheme.fromSeed(seedColor: Colors.deepPurple),
499499
),
500500
home: Scaffold(
501-
appBar: AppBar(
502-
title: const Text('TF Serving Flutter Demo'),
503-
),
501+
appBar: AppBar(title: const Text('TF Serving Flutter Demo')),
504502
body: Center(
505503
child: Container(
506504
padding: const EdgeInsets.fromLTRB(20, 30, 20, 20),
507505
child: Column(
508-
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
509-
children: [
510-
TextField(
511-
controller: _inputSentenceController,
512-
decoration: const InputDecoration(
513-
border: UnderlineInputBorder(),
514-
hintText: 'Enter a sentence here'),
506+
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
507+
children: [
508+
TextField(
509+
controller: _inputSentenceController,
510+
decoration: const InputDecoration(
511+
border: UnderlineInputBorder(),
512+
hintText: 'Enter a sentence here',
515513
),
516-
Column(
517-
children: <Widget>[
518-
ListTile(
519-
title: const Text('gRPC'),
520-
leading: Radio<ConnectionModeType>(
521-
value: ConnectionModeType.grpc,
522-
groupValue: _connectionMode,
523-
onChanged: (value) {
524-
setState(() {
525-
_connectionMode = value;
526-
});
527-
},
528-
),
514+
),
515+
Column(
516+
children: <Widget>[
517+
ListTile(
518+
title: const Text('gRPC'),
519+
leading: Radio<ConnectionModeType>(
520+
value: ConnectionModeType.grpc,
521+
groupValue: _connectionMode,
522+
onChanged: (value) {
523+
setState(() {
524+
_connectionMode = value;
525+
});
526+
},
529527
),
530-
ListTile(
531-
title: const Text('REST'),
532-
leading: Radio<ConnectionModeType>(
533-
value: ConnectionModeType.rest,
534-
groupValue: _connectionMode,
535-
onChanged: (value) {
536-
setState(() {
537-
_connectionMode = value;
538-
});
539-
},
540-
),
528+
),
529+
ListTile(
530+
title: const Text('REST'),
531+
leading: Radio<ConnectionModeType>(
532+
value: ConnectionModeType.rest,
533+
groupValue: _connectionMode,
534+
onChanged: (value) {
535+
setState(() {
536+
_connectionMode = value;
537+
});
538+
},
541539
),
542-
],
543-
),
544-
Row(
545-
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
546-
children: [
547-
FilledButton(
548-
style: FilledButton.styleFrom(
549-
textStyle: const TextStyle(fontSize: 18),
550-
),
551-
onPressed: () {
552-
setState(() {
553-
_futurePrediction = predict();
554-
});
555-
},
556-
child: const Text('Classify')),
557-
FilledButton(
558-
style: FilledButton.styleFrom(
559-
textStyle: const TextStyle(fontSize: 18),
560-
),
561-
onPressed: () {
562-
setState(() {
563-
_futurePrediction =
564-
Future<String>.value(initialPrompt);
565-
_inputSentenceController.clear();
566-
});
567-
},
568-
child: const Text('Reset'))
569-
]),
570-
FutureBuilder<String>(
571-
future: _futurePrediction,
572-
builder: (context, snapshot) {
573-
if (snapshot.hasData) {
574-
return Text(snapshot.data!);
575-
} else if (snapshot.hasError) {
576-
return Text('${snapshot.error}');
577-
}
578-
// By default, show a loading spinner.
579-
return const CircularProgressIndicator();
580-
},
581-
),
582-
]),
540+
),
541+
],
542+
),
543+
Row(
544+
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
545+
children: [
546+
FilledButton(
547+
style: FilledButton.styleFrom(
548+
textStyle: const TextStyle(fontSize: 18),
549+
),
550+
onPressed: () {
551+
setState(() {
552+
_futurePrediction = predict();
553+
});
554+
},
555+
child: const Text('Classify'),
556+
),
557+
FilledButton(
558+
style: FilledButton.styleFrom(
559+
textStyle: const TextStyle(fontSize: 18),
560+
),
561+
onPressed: () {
562+
setState(() {
563+
_futurePrediction = Future<String>.value(
564+
initialPrompt,
565+
);
566+
_inputSentenceController.clear();
567+
});
568+
},
569+
child: const Text('Reset'),
570+
),
571+
],
572+
),
573+
FutureBuilder<String>(
574+
future: _futurePrediction,
575+
builder: (context, snapshot) {
576+
if (snapshot.hasData) {
577+
return Text(snapshot.data!);
578+
} else if (snapshot.hasError) {
579+
return Text('${snapshot.error}');
580+
}
581+
// By default, show a loading spinner.
582+
return const CircularProgressIndicator();
583+
},
584+
),
585+
],
586+
),
583587
),
584588
),
585589
),
@@ -4076,12 +4080,20 @@ steps:
40764080
patch-u: |
40774081
--- b/tfserving-flutter/codelab2/finished/lib/main.dart
40784082
+++ a/tfserving-flutter/codelab2/finished/lib/main.dart
4079-
@@ -160,21 +160,99 @@ class _TFServingDemoState extends State<TFServingDemo> {
4083+
@@ -7,6 +7,7 @@ import 'package:flutter/material.dart';
4084+
import 'package:flutter/services.dart' show rootBundle;
4085+
import 'package:grpc/grpc.dart';
4086+
import 'package:http/http.dart' as http;
4087+
+import 'package:protobuf/protobuf.dart';
4088+
4089+
import 'proto/generated/tensorflow/core/framework/tensor.pb.dart';
4090+
import 'proto/generated/tensorflow/core/framework/tensor_shape.pb.dart';
4091+
@@ -164,21 +165,107 @@ class _TFServingDemoState extends State<TFServingDemo> {
40804092
// For iOS emulator, desktop and web platforms
40814093
_server = '127.0.0.1';
40824094
}
40834095
- // TODO: build _vocabMap if empty
4084-
4096+
40854097
- // TODO: tokenize the input sentence.
40864098
+ if (_vocabMap.isEmpty) {
40874099
+ final vocabFileString = await rootBundle.loadString(vocabFile);
@@ -4113,7 +4125,7 @@ steps:
41134125
+ break;
41144126
+ }
41154127
+ }
4116-
4128+
41174129
if (_connectionMode == ConnectionModeType.rest) {
41184130
- // TODO: create and send the REST request
41194131
+ final response = await http.post(
@@ -4122,7 +4134,7 @@ steps:
41224134
+ 'instances': [_tokenIndices],
41234135
+ }),
41244136
+ );
4125-
4137+
41264138
- // TODO: process the REST response
41274139
+ if (response.statusCode == 200) {
41284140
+ Map<String, dynamic> result =
@@ -4137,38 +4149,46 @@ steps:
41374149
+ }
41384150
} else {
41394151
- // TODO: create the gRPC request
4140-
+ final channel = ClientChannel(_server,
4141-
+ port: grpcPort,
4142-
+ options:
4143-
+ const ChannelOptions(credentials: ChannelCredentials.insecure()));
4144-
+ _stub = PredictionServiceClient(channel,
4145-
+ options: CallOptions(timeout: const Duration(seconds: 10)));
4152+
+ final channel = ClientChannel(
4153+
+ _server,
4154+
+ port: grpcPort,
4155+
+ options: const ChannelOptions(
4156+
+ credentials: ChannelCredentials.insecure(),
4157+
+ ),
4158+
+ );
4159+
+ _stub = PredictionServiceClient(
4160+
+ channel,
4161+
+ options: CallOptions(timeout: const Duration(seconds: 10)),
4162+
+ );
41464163
+
41474164
+ ModelSpec modelSpec = ModelSpec(
41484165
+ name: 'spam-detection',
41494166
+ signatureName: 'serving_default',
41504167
+ );
41514168
+
41524169
+ TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
4153-
+ TensorShapeProto_Dim inputDim =
4154-
+ TensorShapeProto_Dim(size: Int64(maxSentenceLength));
4155-
+ TensorShapeProto inputTensorShape =
4156-
+ TensorShapeProto(dim: [batchDim, inputDim]);
4170+
+ TensorShapeProto_Dim inputDim = TensorShapeProto_Dim(
4171+
+ size: Int64(maxSentenceLength),
4172+
+ );
4173+
+ TensorShapeProto inputTensorShape = TensorShapeProto(
4174+
+ dim: [batchDim, inputDim],
4175+
+ );
41574176
+ TensorProto inputTensor = TensorProto(
4158-
+ dtype: DataType.DT_INT32,
4159-
+ tensorShape: inputTensorShape,
4160-
+ intVal: _tokenIndices);
4177+
+ dtype: DataType.DT_INT32,
4178+
+ tensorShape: inputTensorShape,
4179+
+ intVal: _tokenIndices,
4180+
+ );
41614181
+
41624182
+ // If you train your own model, make sure to update the input and output
41634183
+ // tensor names.
41644184
+ const inputTensorName = 'input_3';
41654185
+ const outputTensorName = 'dense_5';
4166-
+ PredictRequest request = PredictRequest(
4167-
+ modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});
4168-
4186+
+ PredictRequest request = PredictRequest(modelSpec: modelSpec)
4187+
+ ..inputs.addAll({inputTensorName: inputTensor});
4188+
41694189
- // TODO: send the gRPC request
41704190
+ PredictResponse response = await _stub.predict(request);
4171-
4191+
41724192
- // TODO: process the gRPC response
41734193
+ if (response.outputs.containsKey(outputTensorName)) {
41744194
+ if (response.outputs[outputTensorName]!.floatVal[1] >

0 commit comments

Comments
 (0)