@@ -97,7 +97,8 @@ struct InlineEvent {
9797// / Collect data we may use for training a model.
9898class TrainingLogger final {
9999public:
100- TrainingLogger (StringRef LogFileName, const ModelUnderTrainingRunner *MUTR);
100+ TrainingLogger (StringRef LogFileName, const ModelUnderTrainingRunner *MUTR,
101+ const std::vector<TensorSpec> &FeatureMap);
101102
102103 // / Log one inlining event.
103104 void logInlineEvent (const InlineEvent &Event,
@@ -106,6 +107,8 @@ class TrainingLogger final {
106107private:
107108 StringRef LogFileName;
108109 const ModelUnderTrainingRunner *const MUTR;
110+ const std::vector<TensorSpec> &FeatureMap;
111+
109112 std::unique_ptr<Logger> L;
110113 BitVector Effects;
111114 // / Set these 2 clearly OOB, to make sure we set them later.
@@ -142,9 +145,10 @@ class DevelopmentModeMLInlineAdvisor : public MLInlineAdvisor {
142145public:
143146 DevelopmentModeMLInlineAdvisor (
144147 Module &M, ModuleAnalysisManager &MAM,
145- std::unique_ptr<MLModelRunner> ModelRunner,
146- std::function<bool (CallBase &)> GetDefaultAdvice,
147- std::unique_ptr<TrainingLogger> Logger);
148+ std::function<
149+ std::unique_ptr<MLModelRunner>(const std::vector<TensorSpec> &)>
150+ GetModelRunner,
151+ std::function<bool (CallBase &)> GetDefaultAdvice);
148152
149153 size_t getTotalSizeEstimate ();
150154
@@ -258,9 +262,13 @@ static const std::vector<TensorSpec> TrainingOnlyFeatures{
258262 TensorSpec::createSpec<float >(TFFeedPrefix + " reward" , {1 }),
259263 TensorSpec::createSpec<int32_t >(TFFeedPrefix + " step_type" , {1 })};
260264
261- static const std::vector<TensorSpec> getInputFeatures () {
265+ // add TFFeedPrefix to the names and also add the "TrainingOnlyFeatures" which
266+ // the model runner needs to see present. We don't set them ourselves or
267+ // interact with them.
268+ static const std::vector<TensorSpec>
269+ convertInputFeatures (const std::vector<TensorSpec> &OriginalFeatures) {
262270 std::vector<TensorSpec> InputSpecs;
263- for (const auto &Feature : getFeatureMap () )
271+ for (const auto &Feature : OriginalFeatures )
264272 InputSpecs.push_back (TensorSpec (TFFeedPrefix + Feature.name (), Feature));
265273 append_range (InputSpecs, TrainingOnlyFeatures);
266274 return InputSpecs;
@@ -269,8 +277,9 @@ static const std::vector<TensorSpec> getInputFeatures() {
269277} // namespace
270278
271279TrainingLogger::TrainingLogger (StringRef LogFileName,
272- const ModelUnderTrainingRunner *MUTR)
273- : LogFileName(LogFileName), MUTR(MUTR) {
280+ const ModelUnderTrainingRunner *MUTR,
281+ const std::vector<TensorSpec> &FeatureMap)
282+ : LogFileName(LogFileName), MUTR(MUTR), FeatureMap(FeatureMap) {
274283 // The first output is the inlining decision.
275284 std::vector<TensorSpec> FT (getFeatureMap ().begin (), getFeatureMap ().end ());
276285
@@ -327,15 +336,19 @@ void TrainingLogger::logInlineEvent(const InlineEvent &Event,
327336
328337DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor (
329338 Module &M, ModuleAnalysisManager &MAM,
330- std::unique_ptr<MLModelRunner> ModelRunner,
331- std::function<bool (CallBase &)> GetDefaultAdvice,
332- std::unique_ptr<TrainingLogger> Logger)
333- : MLInlineAdvisor(M, MAM, std::move(ModelRunner), GetDefaultAdvice),
339+ std::function<
340+ std::unique_ptr<MLModelRunner>(const std::vector<TensorSpec> &)>
341+ GetModelRunner,
342+ std::function<bool(CallBase &)> GetDefaultAdvice)
343+ : MLInlineAdvisor(M, MAM, GetModelRunner, GetDefaultAdvice),
334344 IsDoingInference(isa<ModelUnderTrainingRunner>(getModelRunner())),
335- Logger(std::move(Logger)),
336345 InitialNativeSize(isLogging() ? getTotalSizeEstimate() : 0),
337346 CurrentNativeSize(InitialNativeSize) {
338347 // We cannot have the case of neither inference nor logging.
348+ if (!TrainingLog.empty ())
349+ Logger = std::make_unique<TrainingLogger>(
350+ TrainingLog, dyn_cast<ModelUnderTrainingRunner>(ModelRunner.get ()),
351+ getFeatureMap ());
339352 assert (IsDoingInference || isLogging ());
340353}
341354
@@ -401,21 +414,22 @@ std::unique_ptr<InlineAdvisor> llvm::getDevelopmentModeAdvisor(
401414 Module &M, ModuleAnalysisManager &MAM,
402415 std::function<bool (CallBase &)> GetDefaultAdvice) {
403416 auto &Ctx = M.getContext ();
404- std::unique_ptr<MLModelRunner> Runner;
405- if (TFModelUnderTrainingPath.empty ())
406- Runner.reset (new NoInferenceModelRunner (Ctx, getInputFeatures ()));
407- else
408- Runner = ModelUnderTrainingRunner::createAndEnsureValid (
409- Ctx, TFModelUnderTrainingPath, DecisionName, getInputFeatures (),
410- TFOutputSpecOverride);
411- if (!Runner)
412- return nullptr ;
413- std::unique_ptr<TrainingLogger> Logger;
414- if (!TrainingLog.empty ())
415- Logger = std::make_unique<TrainingLogger>(
416- TrainingLog, dyn_cast<ModelUnderTrainingRunner>(Runner.get ()));
417-
418- return std::make_unique<DevelopmentModeMLInlineAdvisor>(
419- M, MAM, std::move (Runner), GetDefaultAdvice, std::move (Logger));
417+ auto RunnerFactory = [&](const std::vector<TensorSpec> &InputFeatures)
418+ -> std::unique_ptr<MLModelRunner> {
419+ std::unique_ptr<MLModelRunner> Runner;
420+ const std::vector<TensorSpec> ConvertedFeatures =
421+ convertInputFeatures (InputFeatures);
422+ if (TFModelUnderTrainingPath.empty ())
423+ Runner.reset (new NoInferenceModelRunner (Ctx, ConvertedFeatures));
424+ else
425+ Runner = ModelUnderTrainingRunner::createAndEnsureValid (
426+ Ctx, TFModelUnderTrainingPath, DecisionName, ConvertedFeatures,
427+ TFOutputSpecOverride);
428+ if (!Runner)
429+ return nullptr ;
430+ return Runner;
431+ };
432+ return std::make_unique<DevelopmentModeMLInlineAdvisor>(M, MAM, RunnerFactory,
433+ GetDefaultAdvice);
420434}
421435#endif // defined(LLVM_HAVE_TFLITE)