-
Notifications
You must be signed in to change notification settings - Fork 68
Open
Description
I am trying to summarize a text using the .onnx file from here https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary.
I am getting OrtValue result and I am trying to get the value but the program crashes every time with no clear error
Lost connection to device.
the Dart compiler exited unexpectedly.
Here is a copy-paste example, I am running it on Linux
Code:
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
void main() {
OrtEnv.instance.init();
runApp(MaterialApp(
home: TextSummarization(),
));
}
class TextSummarization extends StatefulWidget {
@override
_TextSummarizationState createState() => _TextSummarizationState();
}
class _TextSummarizationState extends State<TextSummarization> {
// OrtSession? encoderSession;
// OrtSession? decoderSession;
OrtSession? longT5Session;
String _summary = "";
@override
void initState() {
super.initState();
loadLongT5Session();
}
Future<void> loadLongT5Session() async {
final sessionOptions = OrtSessionOptions();
const assetFileName = 'assets/models/long_t5.onnx';
final rawAssetFile = await rootBundle.load(assetFileName);
final bytes = rawAssetFile.buffer.asUint8List();
longT5Session = OrtSession.fromBuffer(bytes, sessionOptions);
}
Future<List<OrtValue?>?> summeryLongT5(String text) async {
// List<int> tokenizedInput = _tokenizeText(inputText);
List<List<int>> inputList = [
[
947,
19,
3,
9,
418,
13,
1499,
27,
278,
31,
17,
241,
12,
608,
3,
1825,
58,
1,
]
];
List<List<int>> attentionMask = [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
];
List<List<int>> decoderInputIds = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
];
OrtValueTensor inputOrt =
OrtValueTensor.createTensorWithDataList(inputList);
OrtValueTensor attentionMaskOrt =
OrtValueTensor.createTensorWithDataList(attentionMask);
OrtValueTensor decoderInputIdsOrt =
OrtValueTensor.createTensorWithDataList(decoderInputIds);
final inputs = {
'input_ids': inputOrt,
'attention_mask': attentionMaskOrt,
'decoder_input_ids': decoderInputIdsOrt,
};
final runOptions = OrtRunOptions();
print('Input names');
print(longT5Session?.inputNames);
List<OrtValue?>? outputs =
await longT5Session?.runAsync(runOptions, inputs);
inputOrt.release();
attentionMaskOrt.release();
runOptions.release();
outputs?.forEach((element) {
element?.release();
});
if (outputs == null || outputs.isEmpty) {
return null;
}
print(outputs);
print('');
OrtValue? b = outputs[0];
try {
Object? d = b?.value;
} catch (e) {
print(e);
return null;
}
print('This is the output');
print(outputs);
return outputs;
// Post-process the output
// String summary = _decodeSummary(outputs);
// setState(() {
// _summary = summary;
// });
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text('Text Summarization with ONNX'),
),
body: Padding(
padding: const EdgeInsets.all(16.0),
child: Column(
children: [
TextField(
onChanged: (text) {
setState(() {
_summary = "Generating summary...";
});
summeryLongT5(text);
},
decoration: InputDecoration(
hintText: "Enter text to summarize",
border: OutlineInputBorder(),
),
),
SizedBox(height: 20),
TextButton(
onPressed: () => summeryLongT5('Your'),
child: Text('Press to impress'),
),
SizedBox(height: 20),
Text('Summary: $_summary'),
],
),
),
);
}
}Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels