Skip to content

Commit 9181467

Browse files
fixes one dal dispatching issues (#6547)
1 parent a06dadc commit 9181467

File tree

4 files changed

+23
-45
lines changed

4 files changed

+23
-45
lines changed

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Reflection;
78
using Microsoft.ML.Data;
89
using Microsoft.ML.Runtime;
910

@@ -171,5 +172,24 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
171172

172173
[BestFriend]
173174
internal void CancelExecution() => ((ICancelable)_env).CancelExecution();
175+
176+
[BestFriend]
177+
internal static readonly bool OneDalDispatchingEnabled = InitializeOneDalDispatchingEnabled();
178+
179+
private static bool InitializeOneDalDispatchingEnabled()
180+
{
181+
try
182+
{
183+
var asm = Assembly.Load("Microsoft.ML.OneDal");
184+
var type = asm.GetType("Microsoft.ML.OneDal.OneDalUtils");
185+
var method = type.GetMethod("IsDispatchingEnabled", BindingFlags.Public | BindingFlags.Static | BindingFlags.NonPublic);
186+
var result = method.Invoke(null, null);
187+
return (bool)result;
188+
}
189+
catch
190+
{
191+
return false;
192+
}
193+
}
174194
}
175195
}

src/Microsoft.ML.FastTree/RandomForestClassification.cs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ private protected override FastForestBinaryModelParameters TrainModelCore(TrainC
224224
FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
225225
ConvertData(trainData);
226226

227-
if (!trainData.Schema.Weight.HasValue && IsDispatchingToOneDalEnabled())
227+
if (!trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled)
228228
{
229229
if (FastTreeTrainerOptions.FeatureFraction != 1.0)
230230
{
@@ -262,20 +262,6 @@ public static extern unsafe int DecisionForestClassificationCompute(
262262
void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr);
263263
}
264264

265-
[BestFriend]
266-
private bool IsDispatchingToOneDalEnabled()
267-
{
268-
try
269-
{
270-
return OneDalUtils.IsDispatchingEnabled();
271-
}
272-
catch (Exception)
273-
{
274-
// Bail to default implementation upon encountering any situation where dispatch failed
275-
return false;
276-
}
277-
}
278-
279265
[BestFriend]
280266
private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
281267
{

src/Microsoft.ML.FastTree/RandomForestRegression.cs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ private protected override FastForestRegressionModelParameters TrainModelCore(Tr
363363
FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
364364
ConvertData(trainData);
365365

366-
if (!trainData.Schema.Weight.HasValue && IsDispatchingToOneDalEnabled())
366+
if (!trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled)
367367
{
368368
if (FastTreeTrainerOptions.FeatureFraction != 1.0)
369369
{
@@ -395,20 +395,6 @@ public static extern unsafe int DecisionForestRegressionCompute(
395395
void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr);
396396
}
397397

398-
[BestFriend]
399-
private bool IsDispatchingToOneDalEnabled()
400-
{
401-
try
402-
{
403-
return OneDalUtils.IsDispatchingEnabled();
404-
}
405-
catch (Exception)
406-
{
407-
// fall back to original implementation for any circumstance that prevents dispatching
408-
return false;
409-
}
410-
}
411-
412398
[BestFriend]
413399
private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
414400
{

src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -406,20 +406,6 @@ private void ComputeMklRegression(IChannel ch, FloatLabelCursor.Factory cursorFa
406406
xty = null;
407407
}
408408

409-
[BestFriend]
410-
private bool IsDispatchingToOneDalEnabled()
411-
{
412-
try
413-
{
414-
return OneDalUtils.IsDispatchingEnabled();
415-
}
416-
catch (Exception)
417-
{
418-
// Bail to default implementation upon any situation that prevents dispatching
419-
return false;
420-
}
421-
}
422-
423409
private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
424410
{
425411
Host.AssertValue(ch);
@@ -440,7 +426,7 @@ private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory curso
440426
var beta = new Double[m];
441427
Double yMean = 0;
442428

443-
if (IsDispatchingToOneDalEnabled())
429+
if (MLContext.OneDalDispatchingEnabled)
444430
{
445431
ComputeOneDalRegression(ch, cursorFactory, m, ref beta, xtx, ref n, ref yMean);
446432
}

0 commit comments

Comments
 (0)