Skip to content

Commit ff3a274

Browse files
authored
Merge pull request #928 from eraly/eraly_bert
bert inference on fine tuned TF model
2 parents 4ed29a6 + 89b90c7 commit ff3a274

File tree

2 files changed

+337
-43
lines changed

2 files changed

+337
-43
lines changed
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.examples.modelimport.tensorflow;
17+
18+
import org.apache.commons.io.FileUtils;
19+
import org.deeplearning4j.examples.download.DownloaderUtility;
20+
import org.deeplearning4j.iterator.BertIterator;
21+
import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider;
22+
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
23+
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
24+
import org.nd4j.autodiff.samediff.SDVariable;
25+
import org.nd4j.autodiff.samediff.SameDiff;
26+
import org.nd4j.autodiff.samediff.TrainingConfig;
27+
import org.nd4j.autodiff.samediff.transform.*;
28+
import org.nd4j.evaluation.classification.Evaluation;
29+
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
30+
import org.nd4j.imports.tensorflow.TFImportOverride;
31+
import org.nd4j.imports.tensorflow.TFOpImportFilter;
32+
import org.nd4j.linalg.api.buffer.DataType;
33+
import org.nd4j.linalg.api.ndarray.INDArray;
34+
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
35+
import org.nd4j.linalg.learning.config.Sgd;
36+
import org.nd4j.linalg.primitives.Pair;
37+
import org.nd4j.resources.Downloader;
38+
import org.slf4j.Logger;
39+
import org.slf4j.LoggerFactory;
40+
41+
import java.io.File;
42+
import java.net.URL;
43+
import java.nio.charset.StandardCharsets;
44+
import java.util.*;
/**
 * This example demonstrates how to run inference on a fine-tuned BERT model in SameDiff, where the fine tuning happens outside in TensorFlow.
 * <p>
 * Details on the fine tuning:
 * The pretrained BERT model is fine-tuned on the Microsoft Research Paraphrase Corpus (MRPC) to
 * classify sentence pairs as paraphrases or not. The fine-tuned model is then frozen to give a protobuf (pb) file.
 * More details about how this model and the frozen pb are generated can be found here:
 * https://github.com/KonduitAI/dl4j-dev-tools/tree/master/import-tests/model_zoo/bert
 * <p>
 * The model is evaluated on the test dataset.
 * As expected, it has about 84% accuracy and a 0.89 F1 score, reflecting the training metrics during the fine tuning.
 * The example also demonstrates how to do inference on a single minibatch.
 * <p>
 * As with all TensorFlow models, the frozen pb is imported into a SameDiff graph. In the case of BERT, the graph has to be
 * preprocessed (removing unneeded nodes etc.) before inference can be carried out on it. More details with the code below.
 */
public class BertInferenceExample {
64+
65+
public static Logger log = LoggerFactory.getLogger(BertInferenceExample.class);
66+
67+
public static String bertModelPath;
68+
//This BERT model uses a FIXED (hardcoded) minibatch size, not dynamic as most models use
69+
public static final int MINI_BATCH_SIZE = 4;
70+
public static final int MAX_LENGTH = 128;
71+
72+
public static void main(String[] args) throws Exception {
73+
74+
File frozenBertPB = downloadBERTFineTunedMSPR();
75+
76+
//replace iterator with placeholder for inputs
77+
Map<String, TFImportOverride> iterToPlaceholderOverride = overrideIteratorsToPlaceholders();
78+
//Don't need the "IteratorV2" node from the graph, hence filtering when importing
79+
TFOpImportFilter filterNodeIterV2 = filterNodeByName("IteratorV2");
80+
SameDiff sd = TFGraphMapper.importGraph(frozenBertPB, iterToPlaceholderOverride, filterNodeIterV2);
81+
82+
//rename replaced placeholders with more appropriate names
83+
sd.renameVariable("IteratorGetNext", "tokenIdxs");
84+
sd.renameVariable("IteratorGetNext:1", "mask");
85+
sd.renameVariable("IteratorGetNext:4", "sentenceIdx");
86+
87+
//remove hard coded dropouts for inference
88+
sd = removeHardCodedDropOutOps(sd);
89+
sd.setTrainingConfig(new TrainingConfig.Builder()
90+
.updater(new Sgd())
91+
.dataSetFeatureMapping("tokenIdxs", "sentenceIdx")
92+
.dataSetFeatureMaskMapping("mask")
93+
.dataSetLabelMapping("loss/Softmax").build());
94+
95+
//Downloads test data and sets up the bert iterator correctly
96+
MultiDataSetIterator iterTest = getMSPRTestIterator();
97+
98+
System.out.println("\nRunning inference on the test dataset. This might take a while ... depending on your hardware");
99+
//Evaluates model on the entire test dataset and prints evaluation stats
100+
EvaluationRecord evaluationRecord = sd
101+
.evaluate()
102+
.data(iterTest)
103+
.evaluate("loss/Softmax", 0, new Evaluation()) //0 specifies the label index - needed since this is a multidataset iterator, "loss/Softmax" is the output node of interest
104+
.exec();
105+
System.out.println(evaluationRecord.evaluation("loss/Softmax").stats());
106+
107+
//Four sentence pairs to run inference on
108+
List<Pair<String, String>> sentencePairs = new ArrayList<>();
109+
sentencePairs.add(new Pair<>("The broader Standard & Poor's 500 Index <.SPX> was 0.46 points lower, or 0.05 percent, at 997.02.", "The technology-laced Nasdaq Composite Index .IXIC was up 7.42 points, or 0.45 percent, at 1,653.44."));
110+
sentencePairs.add(new Pair<>("Shares in BA were down 1.5 percent at 168 pence by 1420 GMT, off a low of 164p, in a slightly stronger overall London market.", "Shares in BA were down three percent at 165-1/4 pence by 0933 GMT, off a low of 164 pence, in a stronger market."));
111+
sentencePairs.add(new Pair<>("Last year, Comcast signed 1.5 million new digital cable subscribers.", "Comcast has about 21.3 million cable subscribers, many in the largest U.S. cities."));
112+
sentencePairs.add(new Pair<>("Revenue rose 3.9 percent, to $1.63 billion from $1.57 billion.", "The McLean, Virginia-based company said newspaper revenue increased 5 percent to $1.46 billion."));
113+
//Featurizes them
114+
BertIterator bertIter = (BertIterator) iterTest;
115+
Pair<INDArray[], INDArray[]> featurizedWithMasks = bertIter.featurizeSentencePairs(sentencePairs);
116+
INDArray[] features = featurizedWithMasks.getFirst();
117+
INDArray[] masks = featurizedWithMasks.getSecond();
118+
119+
System.out.println("\nRunning inference on a single minibatch with sentence pairs as follows:");
120+
for (Pair<String,String> sentencePair: sentencePairs) {
121+
System.out.println("\t" + sentencePair.getFirst() + "\t" + sentencePair.getSecond());
122+
}
123+
//Runs inference
124+
INDArray output = sd.batchOutput()
125+
.input("tokenIdxs", features[0])
126+
.input("sentenceIdx", features[1])
127+
.input("mask", masks[0])
128+
.output("loss/Softmax")
129+
.outputSingle();
130+
System.out.println("\n" + output);
131+
132+
}
133+
134+
135+
private static MultiDataSetIterator getMSPRTestIterator() throws Exception {
136+
List<String> sentencesL = new ArrayList<>();
137+
List<String> sentencesR = new ArrayList<>();
138+
List<String> labels = new ArrayList<>();
139+
140+
URL testDataURL = new URL("https://raw.githubusercontent.com/lanwuwei/SPM_toolkit/master/PWIM/data/msrp/test/msr_paraphrase_test.txt");
141+
String testFileName = "msr_paraphrase_test.txt";
142+
String fileMD5 = "b7e1ed816b22c76d51e0f4bd87768056";
143+
//retry download five times
144+
Downloader.download(testFileName, testDataURL, new File(bertModelPath, testFileName), fileMD5, 5);
145+
146+
List<String> lines = FileUtils.readLines(new File(bertModelPath, testFileName), "utf-8");
147+
for (int i = 0; i < lines.size(); i++) {
148+
if (i == 0) continue; //skip header
149+
String line = lines.get(i);
150+
String[] columns = line.split("\t");
151+
//Quality #1 ID #2 ID #1 String #2 String
152+
labels.add(columns[0]);
153+
sentencesL.add(columns[3]);
154+
sentencesR.add(columns[4]);
155+
}
156+
157+
CollectionLabeledPairSentenceProvider labeledPairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesL, sentencesR, labels, null);
158+
File wordPieceTokens = new File(bertModelPath, "uncased/uncased_L-12_H-768_A-12/vocab.txt");
159+
160+
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(wordPieceTokens, true, true, StandardCharsets.UTF_8);
161+
BertIterator b = BertIterator.builder()
162+
.tokenizer(t)
163+
.padMinibatches(true)
164+
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, MAX_LENGTH)
165+
.minibatchSize(MINI_BATCH_SIZE)
166+
.sentencePairProvider(labeledPairSentenceProvider)
167+
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
168+
.vocabMap(t.getVocab())
169+
.task(BertIterator.Task.SEQ_CLASSIFICATION)
170+
.prependToken("[CLS]")
171+
.appendToken("[SEP]")
172+
.build();
173+
174+
return b;
175+
}
176+
177+
private static File downloadBERTFineTunedMSPR() throws Exception {
178+
bertModelPath = DownloaderUtility.BERTEXAMPLE.Download(false);
179+
return new File(bertModelPath, "bert_mrpc_frozen.pb");
180+
}
181+
182+
/**
183+
* These are op import overrides. We skip the IteratorGetNext node and instead create placeholders.
184+
*/
185+
private static Map<String, TFImportOverride> overrideIteratorsToPlaceholders() {
186+
Map<String, TFImportOverride> m = new HashMap<>();
187+
m.put("IteratorGetNext", (inputs, controlDepInputs, nodeDef, initWith, attributesForNode, graph) -> {
188+
return Arrays.asList(
189+
initWith.placeHolder("IteratorGetNext", DataType.INT, MINI_BATCH_SIZE, MAX_LENGTH),
190+
initWith.placeHolder("IteratorGetNext:1", DataType.INT, MINI_BATCH_SIZE, MAX_LENGTH),
191+
initWith.placeHolder("IteratorGetNext:4", DataType.INT, MINI_BATCH_SIZE, MAX_LENGTH)
192+
);
193+
});
194+
return m;
195+
}
196+
197+
private static TFOpImportFilter filterNodeByName(String nodeName) {
198+
TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> {
199+
return nodeName.equals(nodeDef.getName());
200+
};
201+
return filter;
202+
}
203+
204+
/**
205+
* Modify the network to remove hard-coded dropout operations for inference.
206+
* Tensorflow/BERT's dropout is implemented as a set of discrete operations - random, mul, div, floor, etc.
207+
* We need to select all instances of this subgraph, and then remove them from the graph entirely.
208+
* The subgraph to select are defined with predicates and a sub graph processor that passes input to output is used to replace it
209+
*/
210+
private static SameDiff removeHardCodedDropOutOps(SameDiff sd) {
211+
212+
/* Note that in general there are two ways to define subgraphs (larger than 1 operation) for use in GraphTransformUtil
213+
(a) withInputSubgraph - the input must match this predicate, AND it is added to the subgraph (i.e., matched and is selected to be part of the subgraph)
214+
(b) withInputMatching - the input must match this predicate, BUT it is NOT added to the subgraph (i.e., must match only)
215+
In effect, this predicate will match the set of directly connected operations with the following structure:
216+
(.../dropout/div, .../dropout/Floor) -> (.../dropout/mul)
217+
(.../dropout/add) -> (.../dropout/Floor)
218+
(.../dropout/random_uniform) -> (.../dropout/add)
219+
(.../dropout/random_uniform/mul) -> (.../dropout/random_uniform)
220+
(.../dropout/random_uniform/RandomUniform, .../dropout/random_uniform/sub) -> (.../dropout/random_uniform/mul)
221+
222+
Then, for all subgraphs that match this predicate, we will process them (in this case, simply replace the entire subgraph by passing the input to the output)
223+
NOTE: How do you work out the appropriate subgraph to replace? The simplest approach is to visualize the graph - either in TensorBoard or using SameDiff UI.
224+
*/
225+
SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/mul"))
226+
.withInputCount(2)
227+
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/div")))
228+
.withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/Floor"))
229+
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/add"))
230+
.withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform"))
231+
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform/mul"))
232+
.withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform/RandomUniform")))
233+
.withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform/sub")))
234+
)
235+
)
236+
)
237+
);
238+
/*
239+
Create the subgraph processor.
240+
The subgraph processor is applied to each subgraph - i.e., it defines what we should replace it with.
241+
It's a 2-step process:
242+
(1) The SubGraphProcessor is applied to define the replacement subgraph (add any new operations, and define the new outputs, etc).
243+
In this case, we aren't adding any new ops - so we'll just pass the "real" input (pre dropout activations) to the output.
244+
Note that the number of returned outputs must match the existing number of outputs (1 in this case).
245+
Immediately after SubgraphProcessor.processSubgraph returns, both the existing subgraph (to be replaced) and new subgraph (just added)
246+
exist in parallel.
247+
(2) The existing subgraph is then removed from the graph, leaving only the new subgraph (as defined in processSubgraph method)
248+
in its place.
249+
Note that the order of the outputs you return matters!
250+
If the original outputs are [A,B,C] and you return output variables [X,Y,Z], then anywhere "A" was used as input
251+
will now use "X"; similarly Y replaces B, and Z replaces C.
252+
*/
253+
sd = GraphTransformUtil.replaceSubgraphsMatching(sd, p, new SubGraphProcessor() {
254+
@Override
255+
public List<SDVariable> processSubgraph(SameDiff sd, SubGraph subGraph) {
256+
List<SDVariable> inputs = subGraph.inputs();
257+
SDVariable newOut = null;
258+
for (SDVariable v : inputs) {
259+
if (v.getVarName().endsWith("/BiasAdd") || v.getVarName().endsWith("/Softmax") || v.getVarName().endsWith("/add_1") || v.getVarName().endsWith("/Tanh")) {
260+
newOut = v;
261+
break;
262+
}
263+
}
264+
265+
if (newOut != null) {
266+
return Collections.singletonList(newOut);
267+
}
268+
269+
throw new RuntimeException("No pre-dropout input variable found");
270+
}
271+
});
272+
273+
return sd;
274+
}
275+
}

0 commit comments

Comments
 (0)