/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.examples.modelimport.tensorflow;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.examples.download.DownloaderUtility;
import org.deeplearning4j.iterator.BertIterator;
import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.transform.*;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.imports.tensorflow.TFImportOverride;
import org.nd4j.imports.tensorflow.TFOpImportFilter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Downloader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.*;


/**
 * This example demonstrates how to run inference in SameDiff on a fine-tuned BERT model, where the fine-tuning was done externally in TensorFlow.
 * <p>
 * Details on the fine-tuning:
 * The pretrained BERT model is fine-tuned on the Microsoft Research Paraphrase Corpus (MRPC) to
 * classify sentence pairs as semantically equivalent or not. The fine-tuned model is then frozen to give a protobuf (.pb) file.
 * More details about how this model and the frozen .pb were generated can be found here:
 * https://github.com/KonduitAI/dl4j-dev-tools/tree/master/import-tests/model_zoo/bert
 * <p>
 * The model is evaluated on the test dataset.
 * As expected, it gives about 84% accuracy and an F1 score of 0.89, in line with the metrics observed during fine-tuning.
 * The example also demonstrates how to run inference on a single minibatch.
 * <p>
 * As with any TensorFlow model, the frozen .pb is imported into a SameDiff graph. In the case of BERT, the graph has to be
 * preprocessed (removing unneeded nodes etc.) before inference can be run on it. More details in the code below.
 */
public class BertInferenceExample {

    public static Logger log = LoggerFactory.getLogger(BertInferenceExample.class);

    public static String bertModelPath;
    //This BERT model uses a FIXED (hardcoded) minibatch size, unlike the dynamic minibatch size most models use
    public static final int MINI_BATCH_SIZE = 4;
    public static final int MAX_LENGTH = 128;

    public static void main(String[] args) throws Exception {

        File frozenBertPB = downloadBERTFineTunedMRPC();

        //Replace the dataset iterator ops with placeholders for the inputs
        Map<String, TFImportOverride> iterToPlaceholderOverride = overrideIteratorsToPlaceholders();
        //The "IteratorV2" node is not needed in the graph, so it is filtered out during import
        TFOpImportFilter filterNodeIterV2 = filterNodeByName("IteratorV2");
        SameDiff sd = TFGraphMapper.importGraph(frozenBertPB, iterToPlaceholderOverride, filterNodeIterV2);

        //Rename the newly created placeholders to more descriptive names
        sd.renameVariable("IteratorGetNext", "tokenIdxs");
        sd.renameVariable("IteratorGetNext:1", "mask");
        sd.renameVariable("IteratorGetNext:4", "sentenceIdx");
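        //The frozen graph originally pulled minibatches from a tf.data iterator; outputs 0, 1 and 4 of
        //IteratorGetNext carry the token indices, input mask and segment ids respectively, which is why
        //exactly those three outputs were replaced with placeholders during import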

        //Remove the hard-coded dropout ops for inference
        sd = removeHardCodedDropOutOps(sd);
        sd.setTrainingConfig(new TrainingConfig.Builder()
            .updater(new Sgd())
            .dataSetFeatureMapping("tokenIdxs", "sentenceIdx")
            .dataSetFeatureMaskMapping("mask")
            .dataSetLabelMapping("loss/Softmax").build());
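        //Note: no fit() call is made in this example. The TrainingConfig is set so that SameDiff knows how to map
        //the feature/mask/label arrays from the MultiDataSet iterator onto the named placeholders during evaluation;
        //the Sgd updater is effectively just a required placeholder value here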

        //Download the test data and set up the BERT iterator
        MultiDataSetIterator iterTest = getMRPCTestIterator();

        System.out.println("\nRunning inference on the test dataset. This might take a while, depending on your hardware...");
        //Evaluate the model on the entire test dataset and print the evaluation stats
        EvaluationRecord evaluationRecord = sd
            .evaluate()
            .data(iterTest)
            .evaluate("loss/Softmax", 0, new Evaluation()) //0 is the label index (needed since this is a MultiDataSet iterator); "loss/Softmax" is the output node of interest
            .exec();
        System.out.println(evaluationRecord.evaluation("loss/Softmax").stats());

        //Four sentence pairs to run inference on
        List<Pair<String, String>> sentencePairs = new ArrayList<>();
        sentencePairs.add(new Pair<>("The broader Standard & Poor's 500 Index <.SPX> was 0.46 points lower, or 0.05 percent, at 997.02.", "The technology-laced Nasdaq Composite Index .IXIC was up 7.42 points, or 0.45 percent, at 1,653.44."));
        sentencePairs.add(new Pair<>("Shares in BA were down 1.5 percent at 168 pence by 1420 GMT, off a low of 164p, in a slightly stronger overall London market.", "Shares in BA were down three percent at 165-1/4 pence by 0933 GMT, off a low of 164 pence, in a stronger market."));
        sentencePairs.add(new Pair<>("Last year, Comcast signed 1.5 million new digital cable subscribers.", "Comcast has about 21.3 million cable subscribers, many in the largest U.S. cities."));
        sentencePairs.add(new Pair<>("Revenue rose 3.9 percent, to $1.63 billion from $1.57 billion.", "The McLean, Virginia-based company said newspaper revenue increased 5 percent to $1.46 billion."));
        //Featurize them
        BertIterator bertIter = (BertIterator) iterTest;
        Pair<INDArray[], INDArray[]> featurizedWithMasks = bertIter.featurizeSentencePairs(sentencePairs);
        INDArray[] features = featurizedWithMasks.getFirst();
        INDArray[] masks = featurizedWithMasks.getSecond();

        System.out.println("\nRunning inference on a single minibatch with sentence pairs as follows:");
        for (Pair<String, String> sentencePair : sentencePairs) {
            System.out.println("\t" + sentencePair.getFirst() + "\t" + sentencePair.getSecond());
        }
        //Run inference on the minibatch
        INDArray output = sd.batchOutput()
            .input("tokenIdxs", features[0])
            .input("sentenceIdx", features[1])
            .input("mask", masks[0])
            .output("loss/Softmax")
            .outputSingle();
        System.out.println("\n" + output);
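        //The output has shape [MINI_BATCH_SIZE, 2]: one row of class probabilities per sentence pair
        //(MRPC quality labels: 0 = not a paraphrase, 1 = paraphrase)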

    }

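    /**
     * Downloads the MRPC test split (a tab-separated file) and wraps it in a BertIterator that
     * produces token indices, masks and segment ids in fixed-size, padded minibatches.
     */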
    private static MultiDataSetIterator getMRPCTestIterator() throws Exception {
        List<String> sentencesL = new ArrayList<>();
        List<String> sentencesR = new ArrayList<>();
        List<String> labels = new ArrayList<>();

        URL testDataURL = new URL("https://raw.githubusercontent.com/lanwuwei/SPM_toolkit/master/PWIM/data/msrp/test/msr_paraphrase_test.txt");
        String testFileName = "msr_paraphrase_test.txt";
        String fileMD5 = "b7e1ed816b22c76d51e0f4bd87768056";
        //Retry the download up to five times
        Downloader.download(testFileName, testDataURL, new File(bertModelPath, testFileName), fileMD5, 5);

        List<String> lines = FileUtils.readLines(new File(bertModelPath, testFileName), "utf-8");
        for (int i = 1; i < lines.size(); i++) { //start at 1 to skip the header row
            String line = lines.get(i);
            String[] columns = line.split("\t");
            //Columns: Quality, #1 ID, #2 ID, #1 String, #2 String
            labels.add(columns[0]);
            sentencesL.add(columns[3]);
            sentencesR.add(columns[4]);
        }

        CollectionLabeledPairSentenceProvider labeledPairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesL, sentencesR, labels, null);
        File wordPieceTokens = new File(bertModelPath, "uncased/uncased_L-12_H-768_A-12/vocab.txt");

        BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(wordPieceTokens, true, true, StandardCharsets.UTF_8);
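        //FIXED_LENGTH together with padMinibatches guarantees that every minibatch has shape
        //[MINI_BATCH_SIZE, MAX_LENGTH], matching the fixed placeholder shapes created during import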
        BertIterator b = BertIterator.builder()
            .tokenizer(t)
            .padMinibatches(true)
            .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, MAX_LENGTH)
            .minibatchSize(MINI_BATCH_SIZE)
            .sentencePairProvider(labeledPairSentenceProvider)
            .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
            .vocabMap(t.getVocab())
            .task(BertIterator.Task.SEQ_CLASSIFICATION)
            .prependToken("[CLS]")
            .appendToken("[SEP]")
            .build();

        return b;
    }

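    /**
     * Downloads the archive containing the fine-tuned model and returns the frozen graph (.pb) file.
     */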
    private static File downloadBERTFineTunedMRPC() throws Exception {
        bertModelPath = DownloaderUtility.BERTEXAMPLE.Download(false);
        return new File(bertModelPath, "bert_mrpc_frozen.pb");
    }

    /**
     * Op import overrides: instead of importing the IteratorGetNext node, placeholders are created for its used outputs.
     */
    private static Map<String, TFImportOverride> overrideIteratorsToPlaceholders() {
        Map<String, TFImportOverride> m = new HashMap<>();
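        //Only outputs 0, 1 and 4 of IteratorGetNext are consumed by the inference graph, so placeholders are
        //created for just those three (TensorFlow names the additional outputs with a ":<index>" suffix)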
        m.put("IteratorGetNext", (inputs, controlDepInputs, nodeDef, initWith, attributesForNode, graph) -> {
            return Arrays.asList(
                initWith.placeHolder("IteratorGetNext", DataType.INT, MINI_BATCH_SIZE, MAX_LENGTH),
                initWith.placeHolder("IteratorGetNext:1", DataType.INT, MINI_BATCH_SIZE, MAX_LENGTH),
                initWith.placeHolder("IteratorGetNext:4", DataType.INT, MINI_BATCH_SIZE, MAX_LENGTH)
            );
        });
        return m;
    }

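    /**
     * Returns a filter that excludes the op with the given name during import
     * (returning true from the filter means "skip this node").
     */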
    private static TFOpImportFilter filterNodeByName(String nodeName) {
        TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> {
            return nodeName.equals(nodeDef.getName());
        };
        return filter;
    }

    /**
     * Modify the network to remove the hard-coded dropout operations for inference.
     * TensorFlow/BERT's dropout is implemented as a set of discrete operations - random, mul, div, floor, etc.
     * We need to select all instances of this subgraph, and then remove them from the graph entirely.
     * The subgraphs to select are defined with predicates, and a subgraph processor that passes the input
     * through to the output is used to replace them.
     */
    private static SameDiff removeHardCodedDropOutOps(SameDiff sd) {

        /* Note that in general there are two ways to define subgraphs (larger than 1 operation) for use in GraphTransformUtil
           (a) withInputSubgraph - the input must match this predicate, AND it is added to the subgraph (i.e., matched and selected to be part of the subgraph)
           (b) withInputMatching - the input must match this predicate, BUT it is NOT added to the subgraph (i.e., must match only)
           In effect, this predicate will match the set of directly connected operations with the following structure:
           (.../dropout/div, .../dropout/Floor) -> (.../dropout/mul)
           (.../dropout/add) -> (.../dropout/Floor)
           (.../dropout/random_uniform) -> (.../dropout/add)
           (.../dropout/random_uniform/mul) -> (.../dropout/random_uniform)
           (.../dropout/random_uniform/RandomUniform, .../dropout/random_uniform/sub) -> (.../dropout/random_uniform/mul)

           Then, for all subgraphs that match this predicate, we will process them (in this case, simply replace the entire subgraph by passing the input to the output)
           NOTE: How do you work out the appropriate subgraph to replace? The simplest approach is to visualize the graph - either in TensorBoard or using the SameDiff UI.
        */
        SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/mul"))
            .withInputCount(2)
            .withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/div")))
            .withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/Floor"))
                .withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/add"))
                    .withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform"))
                        .withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform/mul"))
                            .withInputSubgraph(0, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform/RandomUniform")))
                            .withInputSubgraph(1, SubGraphPredicate.withRoot(OpPredicate.nameMatches(".*/dropout/random_uniform/sub")))
                        )
                    )
                )
            );
        /*
           Create the subgraph processor.
           The subgraph processor is applied to each matching subgraph - i.e., it defines what we should replace it with.
           It's a 2-step process:
           (1) The SubGraphProcessor is applied to define the replacement subgraph (add any new operations, define the new outputs, etc.).
               In this case, we aren't adding any new ops - so we'll just pass the "real" input (the pre-dropout activations) to the output.
               Note that the number of returned outputs must match the existing number of outputs (1 in this case).
               Immediately after SubGraphProcessor.processSubgraph returns, both the existing subgraph (to be replaced) and the new subgraph (just added)
               exist in parallel.
           (2) The existing subgraph is then removed from the graph, leaving only the new subgraph (as defined in the processSubgraph method)
               in its place.
           Note that the order of the outputs you return matters!
           If the original outputs are [A,B,C] and you return output variables [X,Y,Z], then anywhere "A" was used as input
           will now use "X"; similarly Y replaces B, and Z replaces C.
        */
        sd = GraphTransformUtil.replaceSubgraphsMatching(sd, p, new SubGraphProcessor() {
            @Override
            public List<SDVariable> processSubgraph(SameDiff sd, SubGraph subGraph) {
                List<SDVariable> inputs = subGraph.inputs();
                SDVariable newOut = null;
                //The pre-dropout activation is the one subgraph input that is not part of the dropout pattern itself
                for (SDVariable v : inputs) {
                    if (v.getVarName().endsWith("/BiasAdd") || v.getVarName().endsWith("/Softmax") || v.getVarName().endsWith("/add_1") || v.getVarName().endsWith("/Tanh")) {
                        newOut = v;
                        break;
                    }
                }

                if (newOut != null) {
                    return Collections.singletonList(newOut);
                }

                throw new RuntimeException("No pre-dropout input variable found");
            }
        });

        return sd;
    }
}