Skip to content

Commit 4ad254e

Browse files
kberg011kberg0
authored andcommitted
Alter .mlfiles to be recognized azassets
Signed-off-by: kberg-amzn <[email protected]>
1 parent 027a8d5 commit 4ad254e

22 files changed

+457
-137
lines changed

Gems/MachineLearning/Code/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ if(PAL_TRAIT_BUILD_HOST_TOOLS)
156156
BUILD_DEPENDENCIES
157157
PUBLIC
158158
AZ::AzToolsFramework
159-
$<TARGET_OBJECTS:Gem::${gem_name}.Private.Object>
159+
Gem::${gem_name}.Private.Object
160160
)
161161

162162
ly_add_target(
@@ -180,8 +180,8 @@ if(PAL_TRAIT_BUILD_HOST_TOOLS)
180180
# By default, we will specify that the above target ${gem_name} would be used by
181181
# Tool and Builder type targets when this gem is enabled. If you don't want it
182182
# active in Tools or Builders by default, delete one of both of the following lines:
183-
ly_create_alias(NAME ${gem_name}.Tools NAMESPACE Gem TARGETS Gem::${gem_name} Gem::${gem_name}.Debug Gem::ScriptCanvas.Editor)
184-
ly_create_alias(NAME ${gem_name}.Builders NAMESPACE Gem TARGETS Gem::${gem_name} Gem::ScriptCanvas.Editor)
183+
ly_create_alias(NAME ${gem_name}.Tools NAMESPACE Gem TARGETS Gem::${gem_name}.Editor Gem::${gem_name}.Debug Gem::ScriptCanvas.Editor)
184+
ly_create_alias(NAME ${gem_name}.Builders NAMESPACE Gem TARGETS Gem::${gem_name}.Editor Gem::ScriptCanvas.Editor)
185185

186186
# For the Tools and Builders variants of ${gem_name} Gem, an alias to the ${gem_name}.Editor API target will be made
187187
ly_create_alias(NAME ${gem_name}.Tools.API NAMESPACE Gem TARGETS Gem::${gem_name}.Editor.API)

Gems/MachineLearning/Code/Include/MachineLearning/Types.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,17 @@ namespace MachineLearning
2525
);
2626

2727
AZ_ENUM_CLASS(AssetTypes,
28-
Model,
2928
TestData,
3029
TestLabels,
3130
TrainingData,
3231
TrainingLabels
3332
);
33+
34+
class IAssetPersistenceProxy
35+
{
36+
public:
37+
virtual ~IAssetPersistenceProxy() = default;
38+
virtual bool SaveAsset() = 0;
39+
virtual bool LoadAsset() = 0;
40+
};
3441
}

Gems/MachineLearning/Code/Source/Algorithms/Training.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ namespace MachineLearning
6363
{
6464
InitializeContexts();
6565

66-
const AZStd::size_t totalTrainingSize = m_trainData.GetSampleCount();
67-
6866
// Start training
6967
m_currentEpoch = 0;
7068
m_trainingComplete = false;

Gems/MachineLearning/Code/Source/Assets/MnistDataLoader.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <AzCore/RTTI/BehaviorContext.h>
1919
#include <AzCore/Serialization/EditContext.h>
2020
#include <AzCore/Serialization/SerializeContext.h>
21-
21+
#pragma optimize("", off)
2222
namespace MachineLearning
2323
{
2424
void MnistDataLoader::Reflect(AZ::ReflectContext* context)
@@ -195,3 +195,4 @@ namespace MachineLearning
195195
return true;
196196
}
197197
}
198+
#pragma optimize("", on)

Gems/MachineLearning/Code/Source/Assets/ModelAsset.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@
1212

1313
namespace MachineLearning
1414
{
15+
void ModelAsset::Reflect(AZ::ReflectContext* context)
16+
{
17+
if (auto serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
18+
{
19+
serializeContext->Class<ModelAsset>()
20+
->Version(1);
21+
22+
if (AZ::EditContext* editContext = serializeContext->GetEditContext())
23+
{
24+
editContext->Class<ModelAsset>("ML Model Asset", "ML Model Asset")
25+
->ClassElement(AZ::Edit::ClassElements::EditorData, "");
26+
}
27+
}
28+
}
29+
1530
bool ModelAsset::Serialize(AzNetworking::ISerializer& serializer)
1631
{
1732
return serializer.Serialize(m_name, "Name")

Gems/MachineLearning/Code/Source/Assets/ModelAsset.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace MachineLearning
2727
AZ_RTTI(ModelAsset, "{4D8D3782-DC3A-499A-A59D-542B85F5EDE9}", AZ::Data::AssetData);
2828
AZ_CLASS_ALLOCATOR(ModelAsset, AZ::SystemAllocator);
2929

30+
static void Reflect(AZ::ReflectContext* context);
31+
3032
~ModelAsset() = default;
3133

3234
//! Base serialize method for all serializable structures or classes to implement.

Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.cpp

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <AzCore/RTTI/BehaviorContext.h>
1515
#include <AzCore/Serialization/EditContext.h>
1616
#include <AzCore/Serialization/SerializeContext.h>
17+
#include <AzCore/Console/ILogger.h>
1718

1819
namespace MachineLearning
1920
{
@@ -23,32 +24,21 @@ namespace MachineLearning
2324
{
2425
serializeContext->Class<MultilayerPerceptronComponent>()
2526
->Version(0)
27+
->Field("Asset", &MultilayerPerceptronComponent::m_asset)
2628
->Field("Model", &MultilayerPerceptronComponent::m_model)
2729
;
28-
29-
if (AZ::EditContext* editContext = serializeContext->GetEditContext())
30-
{
31-
editContext->Class<MultilayerPerceptronComponent>("Multilayer Perceptron", "")
32-
->ClassElement(AZ::Edit::ClassElements::EditorData, "")
33-
->Attribute(AZ::Edit::Attributes::Category, "MachineLearning")
34-
->Attribute(AZ::Edit::Attributes::Icon, "Editor/Icons/Components/NeuralNetwork.svg")
35-
->Attribute(AZ::Edit::Attributes::ViewportIcon, "Editor/Icons/Components/Viewport/NeuralNetwork.svg")
36-
->Attribute(AZ::Edit::Attributes::AppearsInAddComponentMenu, AZ_CRC_CE("Game"))
37-
->DataElement(AZ::Edit::UIHandlers::Default, &MultilayerPerceptronComponent::m_model, "Model", "This is the machine-learning model provided by this component")
38-
;
39-
}
4030
}
4131

4232
auto behaviorContext = azrtti_cast<AZ::BehaviorContext*>(context);
4333
if (behaviorContext)
4434
{
45-
behaviorContext->Class<MultilayerPerceptronComponent>("MultilayerPerceptron Component")->
46-
Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)->
47-
Attribute(AZ::Script::Attributes::Module, "machineLearning")->
48-
Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)->
49-
Constructor<>()->
50-
Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)->
51-
Property("Model", BehaviorValueProperty(&MultilayerPerceptronComponent::m_model))
35+
behaviorContext->Class<MultilayerPerceptronComponent>("MultilayerPerceptron Component")
36+
->Attribute(AZ::Script::Attributes::Scope, AZ::Script::Attributes::ScopeFlags::Common)
37+
->Attribute(AZ::Script::Attributes::Module, "machineLearning")
38+
->Attribute(AZ::Script::Attributes::ExcludeFrom, AZ::Script::Attributes::ExcludeFlags::ListOnly)
39+
->Constructor<>()
40+
->Attribute(AZ::Script::Attributes::Storage, AZ::Script::Attributes::StorageType::Value)
41+
->Property("Model", BehaviorValueProperty(&MultilayerPerceptronComponent::m_model))
5242
;
5343

5444
behaviorContext->EBus<MultilayerPerceptronComponentRequestBus>("Multilayer perceptron requests")
@@ -79,15 +69,63 @@ namespace MachineLearning
7969
void MultilayerPerceptronComponent::Activate()
8070
{
8171
MultilayerPerceptronComponentRequestBus::Handler::BusConnect(GetEntityId());
72+
AssetChanged();
8273
}
8374

8475
void MultilayerPerceptronComponent::Deactivate()
8576
{
77+
AZ::Data::AssetBus::Handler::BusDisconnect();
8678
MultilayerPerceptronComponentRequestBus::Handler::BusDisconnect();
8779
}
8880

8981
INeuralNetworkPtr MultilayerPerceptronComponent::GetModel()
9082
{
9183
return m_handle;
9284
}
85+
86+
void MultilayerPerceptronComponent::AssetChanged()
87+
{
88+
AZ::Data::AssetBus::Handler::BusDisconnect();
89+
if (m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::Error ||
90+
m_asset.GetStatus() == AZ::Data::AssetData::AssetStatus::NotLoaded)
91+
{
92+
m_asset.QueueLoad();
93+
}
94+
AZ::Data::AssetBus::Handler::BusConnect(m_asset.GetId());
95+
}
96+
97+
void MultilayerPerceptronComponent::AssetCleared()
98+
{
99+
;
100+
}
101+
102+
void MultilayerPerceptronComponent::OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset)
103+
{
104+
ModelAsset* modelAsset = asset.GetAs<ModelAsset>();
105+
if ((asset == m_asset) && (modelAsset != nullptr))
106+
{
107+
m_model = *modelAsset;
108+
}
109+
}
110+
111+
void MultilayerPerceptronComponent::OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset)
112+
{
113+
OnAssetReady(asset);
114+
}
115+
116+
void MultilayerPerceptronComponent::OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset)
117+
{
118+
if (asset == m_asset)
119+
{
120+
AZLOG_WARN("OnAssetError: %s", asset.GetHint().c_str());
121+
}
122+
}
123+
124+
void MultilayerPerceptronComponent::OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset)
125+
{
126+
if (asset == m_asset)
127+
{
128+
AZLOG_WARN("OnAssetReloadError: %s", asset.GetHint().c_str());
129+
}
130+
}
93131
}

Gems/MachineLearning/Code/Source/Components/MultilayerPerceptronComponent.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#pragma once
1010

1111
#include <AzCore/Component/Component.h>
12+
#include <AzCore/Asset/AssetCommon.h>
1213
#include <Models/MultilayerPerceptron.h>
14+
#include <Assets/ModelAsset.h>
1315

1416
namespace MachineLearning
1517
{
@@ -25,6 +27,7 @@ namespace MachineLearning
2527

2628
class MultilayerPerceptronComponent
2729
: public AZ::Component
30+
, private AZ::Data::AssetBus::Handler
2831
, public MultilayerPerceptronComponentRequestBus::Handler
2932
{
3033
public:
@@ -52,7 +55,22 @@ namespace MachineLearning
5255

5356
private:
5457

58+
// Edit context callbacks
59+
void AssetChanged();
60+
void AssetCleared();
61+
62+
// AZ::Data::AssetBus ...
63+
void OnAssetReady(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
64+
void OnAssetReloaded(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
65+
void OnAssetError(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
66+
void OnAssetReloadError(AZ::Data::Asset<AZ::Data::AssetData> asset) override;
67+
68+
//! The model asset.
69+
AZ::Data::Asset<ModelAsset> m_asset;
70+
5571
MultilayerPerceptron m_model;
5672
INeuralNetworkPtr m_handle;
73+
74+
friend class MultilayerPerceptronEditorComponent;
5775
};
5876
}

Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugSystemComponent.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,15 @@ namespace MachineLearning
6666
| ImGuiTableFlags_RowBg
6767
| ImGuiTableFlags_NoBordersInBody;
6868

69-
const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
70-
7169
IMachineLearning* machineLearning = MachineLearningInterface::Get();
7270
const ModelSet& modelSet = machineLearning->GetModelSet();
7371

7472
ImGui::Text("Total registered models: %u", static_cast<uint32_t>(modelSet.size()));
7573
ImGui::NewLine();
7674

77-
if (ImGui::BeginTable("Model Details", 6, flags))
75+
if (ImGui::BeginTable("Model Details", 5, flags))
7876
{
7977
ImGui::TableSetupColumn("Name", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 32.0f);
80-
ImGui::TableSetupColumn("File", ImGuiTableColumnFlags_WidthStretch);
8178
ImGui::TableSetupColumn("Input Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
8279
ImGui::TableSetupColumn("Output Neurons", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
8380
ImGui::TableSetupColumn("Layers", ImGuiTableColumnFlags_WidthFixed, TEXT_BASE_WIDTH * 12.0f);
@@ -91,8 +88,6 @@ namespace MachineLearning
9188
ImGui::TableNextColumn();
9289
ImGui::Text(neuralNetwork->GetName().c_str());
9390
ImGui::TableNextColumn();
94-
ImGui::Text(neuralNetwork->GetAssetFile(AssetTypes::Model).c_str());
95-
ImGui::TableNextColumn();
9691
ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetInputDimensionality()));
9792
ImGui::TableNextColumn();
9893
ImGui::Text("%lld", aznumeric_cast<AZ::s64>(neuralNetwork->GetOutputDimensionality()));

Gems/MachineLearning/Code/Source/Debug/MachineLearningDebugTrainingWindow.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ namespace MachineLearning
151151
| ImGuiTableFlags_RowBg
152152
| ImGuiTableFlags_NoBordersInBody;
153153

154-
const ImGuiTreeNodeFlags nodeFlags = (ImGuiTreeNodeFlags_Leaf | ImGuiTreeNodeFlags_NoTreePushOnOpen | ImGuiTreeNodeFlags_SpanFullWidth);
155-
156154
IMachineLearning* machineLearning = MachineLearningInterface::Get();
157155
const ModelSet& modelSet = machineLearning->GetModelSet();
158156

@@ -262,7 +260,6 @@ namespace MachineLearning
262260

263261
ImGui::NewLine();
264262
ImGui::Text("Model Name: %s", m_selectedModel->GetName().c_str());
265-
ImGui::Text("Asset location: %s", m_selectedModel->GetAssetFile(AssetTypes::Model).c_str());
266263

267264
if (ImGui::BeginTable("Accuracy", 2, flags))
268265
{

0 commit comments

Comments
 (0)