diff --git a/.gitignore b/.gitignore
index 4f8e77a3e..749832847 100644
--- a/.gitignore
+++ b/.gitignore
@@ -273,3 +273,7 @@ packages/
/.idea
/test/TorchSharpTest/exportsd.py
.vscode/settings.json
+/TestClear
+TestClear/
+/nuget.config
+/src/Native/LibTorchSharp/third_party
diff --git a/Directory.Build.props b/Directory.Build.props
index 256170d0a..1d1ea71f9 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -5,6 +5,7 @@
+
Debug
Debug;Release
<_DefaultArchitecture>$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture.ToString().ToLower())
@@ -92,7 +93,6 @@
$(LibTorchPackageVersion)
-
true
@@ -164,8 +164,11 @@
$(DefineContants);DEBUG
false
+
+ $(DefineContants);CUDA_TOOLKIT_FOUND
+
true
-
+
\ No newline at end of file
diff --git a/nuget.config b/nuget.config
new file mode 100644
index 000000000..ef5d6f41e
--- /dev/null
+++ b/nuget.config
@@ -0,0 +1,4 @@
+
+
+ F:\NugetPackages
+
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.dgspec.json b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.dgspec.json
new file mode 100644
index 000000000..0101447be
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.dgspec.json
@@ -0,0 +1,224 @@
+{
+ "format": 1,
+ "restore": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.csproj": {}
+ },
+ "projects": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.csproj": {
+ "version": "1.0.0",
+ "restore": {
+ "projectUniqueName": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.csproj",
+ "projectName": "FileRestitcher.Tests",
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.csproj",
+ "packagesPath": "C:\\Users\\Dimitri\\.nuget\\packages\\",
+ "outputPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.NupkgProj\\",
+ "projectStyle": "PackageReference",
+ "crossTargeting": true,
+ "fallbackFolders": [
+ "C:\\Program Files (x86)\\Microsoft Visual Studio\\Shared\\NuGetPackages"
+ ],
+ "configFilePaths": [
+ "K:\\Proyects_Repos\\TorchSharp\\NuGet.Config",
+ "C:\\Users\\Dimitri\\AppData\\Roaming\\NuGet\\NuGet.Config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.FallbackLocation.config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.Offline.config"
+ ],
+ "originalTargetFrameworks": [
+ "net472",
+ "netstandard2.0"
+ ],
+ "sources": {
+ "C:\\Program Files (x86)\\Microsoft SDKs\\NuGetPackages\\": {},
+ "https://api.nuget.org/v3/index.json": {}
+ },
+ "frameworks": {
+ "net472": {
+ "targetAlias": "net472",
+ "projectReferences": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj": {
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj"
+ }
+ }
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "projectReferences": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj": {
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj"
+ }
+ }
+ }
+ },
+ "warningProperties": {
+ "warnAsError": [
+ "NU1605"
+ ]
+ },
+ "restoreAuditProperties": {
+ "enableAudit": "true",
+ "auditLevel": "low",
+ "auditMode": "all"
+ },
+ "SdkAnalysisLevel": "9.0.100"
+ },
+ "frameworks": {
+ "net472": {
+ "targetAlias": "net472",
+ "dependencies": {
+ "Microsoft.NET.Test.Sdk": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[16.9.4, )"
+ },
+ "coverlet.collector": {
+ "include": "Runtime, Build, Native, ContentFiles, Analyzers, BuildTransitive",
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[3.0.2, )"
+ },
+ "xunit": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[2.4.2, )"
+ }
+ },
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "dependencies": {
+ "Microsoft.NET.Test.Sdk": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[16.9.4, )"
+ },
+ "NETStandard.Library": {
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[2.0.3, )",
+ "autoReferenced": true
+ },
+ "coverlet.collector": {
+ "include": "Runtime, Build, Native, ContentFiles, Analyzers, BuildTransitive",
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[3.0.2, )"
+ },
+ "xunit": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[2.4.2, )"
+ }
+ },
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ }
+ }
+ },
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj": {
+ "version": "1.0.0",
+ "restore": {
+ "projectUniqueName": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj",
+ "projectName": "FileRestitcher",
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj",
+ "packagesPath": "C:\\Users\\Dimitri\\.nuget\\packages\\",
+ "outputPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.NupkgProj\\",
+ "projectStyle": "PackageReference",
+ "crossTargeting": true,
+ "fallbackFolders": [
+ "C:\\Program Files (x86)\\Microsoft Visual Studio\\Shared\\NuGetPackages"
+ ],
+ "configFilePaths": [
+ "K:\\Proyects_Repos\\TorchSharp\\NuGet.Config",
+ "C:\\Users\\Dimitri\\AppData\\Roaming\\NuGet\\NuGet.Config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.FallbackLocation.config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.Offline.config"
+ ],
+ "originalTargetFrameworks": [
+ "net6.0",
+ "netstandard2.0"
+ ],
+ "sources": {
+ "C:\\Program Files (x86)\\Microsoft SDKs\\NuGetPackages\\": {},
+ "https://api.nuget.org/v3/index.json": {}
+ },
+ "frameworks": {
+ "net6.0": {
+ "targetAlias": "net6.0",
+ "projectReferences": {}
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "projectReferences": {}
+ }
+ },
+ "warningProperties": {
+ "warnAsError": [
+ "NU1605"
+ ]
+ },
+ "restoreAuditProperties": {
+ "enableAudit": "true",
+ "auditLevel": "low",
+ "auditMode": "all"
+ },
+ "SdkAnalysisLevel": "9.0.100"
+ },
+ "frameworks": {
+ "net6.0": {
+ "targetAlias": "net6.0",
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "frameworkReferences": {
+ "Microsoft.NETCore.App": {
+ "privateAssets": "all"
+ }
+ },
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "dependencies": {
+ "NETStandard.Library": {
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[2.0.3, )",
+ "autoReferenced": true
+ }
+ },
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.g.props b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.g.props
new file mode 100644
index 000000000..7adfe6ee9
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.g.props
@@ -0,0 +1,35 @@
+
+
+
+ True
+ NuGet
+ $(MSBuildThisFileDirectory)project.assets.json
+ $(UserProfile)\.nuget\packages\
+ C:\Users\Dimitri\.nuget\packages\;C:\Program Files (x86)\Microsoft Visual Studio\Shared\NuGetPackages
+ PackageReference
+ 6.12.0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ C:\Users\Dimitri\.nuget\packages\xunit.analyzers\1.0.0
+
+
+ C:\Users\Dimitri\.nuget\packages\xunit.analyzers\1.0.0
+
+
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.g.targets b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.g.targets
new file mode 100644
index 000000000..89347f8d0
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/FileRestitcher.Tests.csproj.nuget.g.targets
@@ -0,0 +1,18 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/project.assets.json b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/project.assets.json
new file mode 100644
index 000000000..ac4726f8d
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/project.assets.json
@@ -0,0 +1,841 @@
+{
+ "version": 3,
+ "targets": {
+ ".NETFramework,Version=v4.7.2": {
+ "coverlet.collector/3.0.2": {
+ "type": "package",
+ "build": {
+ "build/netstandard1.0/coverlet.collector.targets": {}
+ }
+ },
+ "Microsoft.CodeCoverage/16.9.4": {
+ "type": "package",
+ "compile": {
+ "lib/net45/Microsoft.VisualStudio.CodeCoverage.Shim.dll": {}
+ },
+ "runtime": {
+ "lib/net45/Microsoft.VisualStudio.CodeCoverage.Shim.dll": {}
+ },
+ "build": {
+ "build/netstandard1.0/Microsoft.CodeCoverage.props": {},
+ "build/netstandard1.0/Microsoft.CodeCoverage.targets": {}
+ }
+ },
+ "Microsoft.NET.Test.Sdk/16.9.4": {
+ "type": "package",
+ "dependencies": {
+ "Microsoft.CodeCoverage": "16.9.4"
+ },
+ "compile": {
+ "lib/net45/_._": {}
+ },
+ "runtime": {
+ "lib/net45/_._": {}
+ },
+ "build": {
+ "build/net45/Microsoft.NET.Test.Sdk.props": {},
+ "build/net45/Microsoft.NET.Test.Sdk.targets": {}
+ },
+ "buildMultiTargeting": {
+ "buildMultiTargeting/Microsoft.NET.Test.Sdk.props": {}
+ }
+ },
+ "xunit/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "xunit.analyzers": "1.0.0",
+ "xunit.assert": "2.4.2",
+ "xunit.core": "[2.4.2]"
+ }
+ },
+ "xunit.abstractions/2.0.3": {
+ "type": "package",
+ "compile": {
+ "lib/net35/xunit.abstractions.dll": {
+ "related": ".xml"
+ }
+ },
+ "runtime": {
+ "lib/net35/xunit.abstractions.dll": {
+ "related": ".xml"
+ }
+ }
+ },
+ "xunit.analyzers/1.0.0": {
+ "type": "package"
+ },
+ "xunit.assert/2.4.2": {
+ "type": "package",
+ "compile": {
+ "lib/netstandard1.1/xunit.assert.dll": {
+ "related": ".xml"
+ }
+ },
+ "runtime": {
+ "lib/netstandard1.1/xunit.assert.dll": {
+ "related": ".xml"
+ }
+ }
+ },
+ "xunit.core/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "xunit.extensibility.core": "[2.4.2]",
+ "xunit.extensibility.execution": "[2.4.2]"
+ },
+ "build": {
+ "build/xunit.core.props": {},
+ "build/xunit.core.targets": {}
+ },
+ "buildMultiTargeting": {
+ "buildMultiTargeting/xunit.core.props": {},
+ "buildMultiTargeting/xunit.core.targets": {}
+ }
+ },
+ "xunit.extensibility.core/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "xunit.abstractions": "2.0.3"
+ },
+ "compile": {
+ "lib/net452/xunit.core.dll": {
+ "related": ".dll.tdnet;.xml"
+ }
+ },
+ "runtime": {
+ "lib/net452/xunit.core.dll": {
+ "related": ".dll.tdnet;.xml"
+ }
+ }
+ },
+ "xunit.extensibility.execution/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "xunit.extensibility.core": "[2.4.2]"
+ },
+ "compile": {
+ "lib/net452/xunit.execution.desktop.dll": {
+ "related": ".xml"
+ }
+ },
+ "runtime": {
+ "lib/net452/xunit.execution.desktop.dll": {
+ "related": ".xml"
+ }
+ }
+ },
+ "FileRestitcher/1.0.0": {
+ "type": "project",
+ "framework": ".NETStandard,Version=v2.0",
+ "compile": {
+ "bin/placeholder/FileRestitcher.dll": {}
+ },
+ "runtime": {
+ "bin/placeholder/FileRestitcher.dll": {}
+ }
+ }
+ },
+ ".NETStandard,Version=v2.0": {
+ "coverlet.collector/3.0.2": {
+ "type": "package",
+ "build": {
+ "build/netstandard1.0/coverlet.collector.targets": {}
+ }
+ },
+ "Microsoft.CodeCoverage/16.9.4": {
+ "type": "package",
+ "build": {
+ "build/netstandard1.0/Microsoft.CodeCoverage.props": {},
+ "build/netstandard1.0/Microsoft.CodeCoverage.targets": {}
+ }
+ },
+ "Microsoft.NET.Test.Sdk/16.9.4": {
+ "type": "package",
+ "dependencies": {
+ "Microsoft.CodeCoverage": "16.9.4"
+ },
+ "buildMultiTargeting": {
+ "buildMultiTargeting/Microsoft.NET.Test.Sdk.props": {}
+ }
+ },
+ "Microsoft.NETCore.Platforms/1.1.0": {
+ "type": "package",
+ "compile": {
+ "lib/netstandard1.0/_._": {}
+ },
+ "runtime": {
+ "lib/netstandard1.0/_._": {}
+ }
+ },
+ "NETStandard.Library/2.0.3": {
+ "type": "package",
+ "dependencies": {
+ "Microsoft.NETCore.Platforms": "1.1.0"
+ },
+ "compile": {
+ "lib/netstandard1.0/_._": {}
+ },
+ "runtime": {
+ "lib/netstandard1.0/_._": {}
+ },
+ "build": {
+ "build/netstandard2.0/NETStandard.Library.targets": {}
+ }
+ },
+ "xunit/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "xunit.analyzers": "1.0.0",
+ "xunit.assert": "2.4.2",
+ "xunit.core": "[2.4.2]"
+ }
+ },
+ "xunit.abstractions/2.0.3": {
+ "type": "package",
+ "compile": {
+ "lib/netstandard2.0/xunit.abstractions.dll": {
+ "related": ".xml"
+ }
+ },
+ "runtime": {
+ "lib/netstandard2.0/xunit.abstractions.dll": {
+ "related": ".xml"
+ }
+ }
+ },
+ "xunit.analyzers/1.0.0": {
+ "type": "package"
+ },
+ "xunit.assert/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "NETStandard.Library": "1.6.1"
+ },
+ "compile": {
+ "lib/netstandard1.1/xunit.assert.dll": {
+ "related": ".xml"
+ }
+ },
+ "runtime": {
+ "lib/netstandard1.1/xunit.assert.dll": {
+ "related": ".xml"
+ }
+ }
+ },
+ "xunit.core/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "xunit.extensibility.core": "[2.4.2]",
+ "xunit.extensibility.execution": "[2.4.2]"
+ },
+ "build": {
+ "build/xunit.core.props": {},
+ "build/xunit.core.targets": {}
+ },
+ "buildMultiTargeting": {
+ "buildMultiTargeting/xunit.core.props": {},
+ "buildMultiTargeting/xunit.core.targets": {}
+ }
+ },
+ "xunit.extensibility.core/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "NETStandard.Library": "1.6.1",
+ "xunit.abstractions": "2.0.3"
+ },
+ "compile": {
+ "lib/netstandard1.1/xunit.core.dll": {
+ "related": ".xml"
+ }
+ },
+ "runtime": {
+ "lib/netstandard1.1/xunit.core.dll": {
+ "related": ".xml"
+ }
+ }
+ },
+ "xunit.extensibility.execution/2.4.2": {
+ "type": "package",
+ "dependencies": {
+ "NETStandard.Library": "1.6.1",
+ "xunit.extensibility.core": "[2.4.2]"
+ },
+ "compile": {
+ "lib/netstandard1.1/xunit.execution.dotnet.dll": {
+ "related": ".xml"
+ }
+ },
+ "runtime": {
+ "lib/netstandard1.1/xunit.execution.dotnet.dll": {
+ "related": ".xml"
+ }
+ }
+ },
+ "FileRestitcher/1.0.0": {
+ "type": "project",
+ "framework": ".NETStandard,Version=v2.0",
+ "compile": {
+ "bin/placeholder/FileRestitcher.dll": {}
+ },
+ "runtime": {
+ "bin/placeholder/FileRestitcher.dll": {}
+ }
+ }
+ }
+ },
+ "libraries": {
+ "coverlet.collector/3.0.2": {
+ "sha512": "iBvPAIDaI7j/iMx/DzCGCJ3rdiOmel9VINEfaTiBv/NKIGHOP4X3hqc6Q1wgMtArEshlhXexQknP17SK4vXb1w==",
+ "type": "package",
+ "path": "coverlet.collector/3.0.2",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "build/netstandard1.0/Microsoft.CSharp.dll",
+ "build/netstandard1.0/Microsoft.DotNet.PlatformAbstractions.dll",
+ "build/netstandard1.0/Microsoft.Extensions.DependencyInjection.Abstractions.dll",
+ "build/netstandard1.0/Microsoft.Extensions.DependencyInjection.dll",
+ "build/netstandard1.0/Microsoft.Extensions.DependencyModel.dll",
+ "build/netstandard1.0/Microsoft.Extensions.FileSystemGlobbing.dll",
+ "build/netstandard1.0/Microsoft.TestPlatform.CoreUtilities.dll",
+ "build/netstandard1.0/Microsoft.TestPlatform.PlatformAbstractions.dll",
+ "build/netstandard1.0/Microsoft.VisualStudio.TestPlatform.ObjectModel.dll",
+ "build/netstandard1.0/Mono.Cecil.Mdb.dll",
+ "build/netstandard1.0/Mono.Cecil.Pdb.dll",
+ "build/netstandard1.0/Mono.Cecil.Rocks.dll",
+ "build/netstandard1.0/Mono.Cecil.dll",
+ "build/netstandard1.0/Newtonsoft.Json.dll",
+ "build/netstandard1.0/NuGet.Frameworks.dll",
+ "build/netstandard1.0/System.AppContext.dll",
+ "build/netstandard1.0/System.Collections.Immutable.dll",
+ "build/netstandard1.0/System.Dynamic.Runtime.dll",
+ "build/netstandard1.0/System.IO.FileSystem.Primitives.dll",
+ "build/netstandard1.0/System.Linq.Expressions.dll",
+ "build/netstandard1.0/System.Linq.dll",
+ "build/netstandard1.0/System.ObjectModel.dll",
+ "build/netstandard1.0/System.Reflection.Emit.ILGeneration.dll",
+ "build/netstandard1.0/System.Reflection.Emit.Lightweight.dll",
+ "build/netstandard1.0/System.Reflection.Emit.dll",
+ "build/netstandard1.0/System.Reflection.Metadata.dll",
+ "build/netstandard1.0/System.Reflection.TypeExtensions.dll",
+ "build/netstandard1.0/System.Runtime.Serialization.Primitives.dll",
+ "build/netstandard1.0/System.Text.RegularExpressions.dll",
+ "build/netstandard1.0/System.Threading.Tasks.Extensions.dll",
+ "build/netstandard1.0/System.Threading.dll",
+ "build/netstandard1.0/System.Xml.ReaderWriter.dll",
+ "build/netstandard1.0/System.Xml.XDocument.dll",
+ "build/netstandard1.0/coverlet.collector.deps.json",
+ "build/netstandard1.0/coverlet.collector.dll",
+ "build/netstandard1.0/coverlet.collector.pdb",
+ "build/netstandard1.0/coverlet.collector.targets",
+ "build/netstandard1.0/coverlet.core.dll",
+ "build/netstandard1.0/coverlet.core.pdb",
+ "coverlet-icon.png",
+ "coverlet.collector.3.0.2.nupkg.sha512",
+ "coverlet.collector.nuspec"
+ ]
+ },
+ "Microsoft.CodeCoverage/16.9.4": {
+ "sha512": "N/RYB07gJkPZ1nJiq0QGxFIL+X5vVl4GI99PiTYXpbfI30NTZMRJgZ+4jYLFYLDQqj9o1Juhv+3iiymd7lozrA==",
+ "type": "package",
+ "path": "microsoft.codecoverage/16.9.4",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "Icon.png",
+ "LICENSE_NET.txt",
+ "build/netstandard1.0/CodeCoverage/CodeCoverage.config",
+ "build/netstandard1.0/CodeCoverage/CodeCoverage.exe",
+ "build/netstandard1.0/CodeCoverage/VanguardInstrumentationProfiler_x86.config",
+ "build/netstandard1.0/CodeCoverage/amd64/CodeCoverage.exe",
+ "build/netstandard1.0/CodeCoverage/amd64/VanguardInstrumentationProfiler_x64.config",
+ "build/netstandard1.0/CodeCoverage/amd64/covrun64.dll",
+ "build/netstandard1.0/CodeCoverage/amd64/msdia140.dll",
+ "build/netstandard1.0/CodeCoverage/amd64/msvcdis140.dll",
+ "build/netstandard1.0/CodeCoverage/amd64/msvcp140.dll",
+ "build/netstandard1.0/CodeCoverage/amd64/msvcp140_atomic_wait.dll",
+ "build/netstandard1.0/CodeCoverage/amd64/vcruntime140.dll",
+ "build/netstandard1.0/CodeCoverage/amd64/vcruntime140_1.dll",
+ "build/netstandard1.0/CodeCoverage/codecoveragemessages.dll",
+ "build/netstandard1.0/CodeCoverage/coreclr/Microsoft.VisualStudio.CodeCoverage.Shim.dll",
+ "build/netstandard1.0/CodeCoverage/covrun32.dll",
+ "build/netstandard1.0/CodeCoverage/msdia140.dll",
+ "build/netstandard1.0/CodeCoverage/msvcdis140.dll",
+ "build/netstandard1.0/CodeCoverage/msvcp140.dll",
+ "build/netstandard1.0/CodeCoverage/msvcp140_atomic_wait.dll",
+ "build/netstandard1.0/CodeCoverage/vcruntime140.dll",
+ "build/netstandard1.0/InstrumentationEngine/x64/MicrosoftInstrumentationEngine_x64.dll",
+ "build/netstandard1.0/InstrumentationEngine/x86/MicrosoftInstrumentationEngine_x86.dll",
+ "build/netstandard1.0/Microsoft.CodeCoverage.props",
+ "build/netstandard1.0/Microsoft.CodeCoverage.targets",
+ "build/netstandard1.0/Microsoft.VisualStudio.Coverage.CoreLib.Net.dll",
+ "build/netstandard1.0/Microsoft.VisualStudio.Coverage.Interprocess.dll",
+ "build/netstandard1.0/Microsoft.VisualStudio.TraceDataCollector.dll",
+ "build/netstandard1.0/cs/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/cs/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/de/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/de/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/es/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/es/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/fr/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/fr/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/it/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/it/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/ja/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/ja/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/ko/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/ko/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/pl/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/pl/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/pt-BR/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/pt-BR/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/ru/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/ru/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/tr/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/tr/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/zh-Hans/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/zh-Hans/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "build/netstandard1.0/zh-Hant/Microsoft.VisualStudio.Coverage.CoreLib.Net.resources.dll",
+ "build/netstandard1.0/zh-Hant/Microsoft.VisualStudio.TraceDataCollector.resources.dll",
+ "lib/net45/Microsoft.VisualStudio.CodeCoverage.Shim.dll",
+ "lib/netcoreapp1.0/Microsoft.VisualStudio.CodeCoverage.Shim.dll",
+ "microsoft.codecoverage.16.9.4.nupkg.sha512",
+ "microsoft.codecoverage.nuspec"
+ ]
+ },
+ "Microsoft.NET.Test.Sdk/16.9.4": {
+ "sha512": "M/k16vmS7Hz/+Kuy3p6XE743XPjYYMzfN5ZvpSLY44Ngh5IBMk0Je5Qed8oq6/kvzJA2DTrXa7YrfceHhbQKeQ==",
+ "type": "package",
+ "path": "microsoft.net.test.sdk/16.9.4",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "Icon.png",
+ "LICENSE_NET.txt",
+ "build/net40/Microsoft.NET.Test.Sdk.props",
+ "build/net40/Microsoft.NET.Test.Sdk.targets",
+ "build/net45/Microsoft.NET.Test.Sdk.props",
+ "build/net45/Microsoft.NET.Test.Sdk.targets",
+ "build/netcoreapp1.0/Microsoft.NET.Test.Sdk.Program.cs",
+ "build/netcoreapp1.0/Microsoft.NET.Test.Sdk.Program.fs",
+ "build/netcoreapp1.0/Microsoft.NET.Test.Sdk.Program.vb",
+ "build/netcoreapp1.0/Microsoft.NET.Test.Sdk.props",
+ "build/netcoreapp1.0/Microsoft.NET.Test.Sdk.targets",
+ "build/netcoreapp2.1/Microsoft.NET.Test.Sdk.Program.cs",
+ "build/netcoreapp2.1/Microsoft.NET.Test.Sdk.Program.fs",
+ "build/netcoreapp2.1/Microsoft.NET.Test.Sdk.Program.vb",
+ "build/netcoreapp2.1/Microsoft.NET.Test.Sdk.props",
+ "build/netcoreapp2.1/Microsoft.NET.Test.Sdk.targets",
+ "build/uap10.0/Microsoft.NET.Test.Sdk.props",
+ "buildMultiTargeting/Microsoft.NET.Test.Sdk.props",
+ "lib/net40/_._",
+ "lib/net45/_._",
+ "lib/netcoreapp1.0/_._",
+ "lib/netcoreapp2.1/_._",
+ "lib/uap10.0/_._",
+ "microsoft.net.test.sdk.16.9.4.nupkg.sha512",
+ "microsoft.net.test.sdk.nuspec"
+ ]
+ },
+ "Microsoft.NETCore.Platforms/1.1.0": {
+ "sha512": "kz0PEW2lhqygehI/d6XsPCQzD7ff7gUJaVGPVETX611eadGsA3A877GdSlU0LRVMCTH/+P3o2iDTak+S08V2+A==",
+ "type": "package",
+ "path": "microsoft.netcore.platforms/1.1.0",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "ThirdPartyNotices.txt",
+ "dotnet_library_license.txt",
+ "lib/netstandard1.0/_._",
+ "microsoft.netcore.platforms.1.1.0.nupkg.sha512",
+ "microsoft.netcore.platforms.nuspec",
+ "runtime.json"
+ ]
+ },
+ "NETStandard.Library/2.0.3": {
+ "sha512": "st47PosZSHrjECdjeIzZQbzivYBJFv6P2nv4cj2ypdI204DO+vZ7l5raGMiX4eXMJ53RfOIg+/s4DHVZ54Nu2A==",
+ "type": "package",
+ "path": "netstandard.library/2.0.3",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "LICENSE.TXT",
+ "THIRD-PARTY-NOTICES.TXT",
+ "build/netstandard2.0/NETStandard.Library.targets",
+ "build/netstandard2.0/ref/Microsoft.Win32.Primitives.dll",
+ "build/netstandard2.0/ref/System.AppContext.dll",
+ "build/netstandard2.0/ref/System.Collections.Concurrent.dll",
+ "build/netstandard2.0/ref/System.Collections.NonGeneric.dll",
+ "build/netstandard2.0/ref/System.Collections.Specialized.dll",
+ "build/netstandard2.0/ref/System.Collections.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.Composition.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.EventBasedAsync.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.Primitives.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.TypeConverter.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.dll",
+ "build/netstandard2.0/ref/System.Console.dll",
+ "build/netstandard2.0/ref/System.Core.dll",
+ "build/netstandard2.0/ref/System.Data.Common.dll",
+ "build/netstandard2.0/ref/System.Data.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Contracts.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Debug.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.FileVersionInfo.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Process.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.StackTrace.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.TextWriterTraceListener.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Tools.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.TraceSource.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Tracing.dll",
+ "build/netstandard2.0/ref/System.Drawing.Primitives.dll",
+ "build/netstandard2.0/ref/System.Drawing.dll",
+ "build/netstandard2.0/ref/System.Dynamic.Runtime.dll",
+ "build/netstandard2.0/ref/System.Globalization.Calendars.dll",
+ "build/netstandard2.0/ref/System.Globalization.Extensions.dll",
+ "build/netstandard2.0/ref/System.Globalization.dll",
+ "build/netstandard2.0/ref/System.IO.Compression.FileSystem.dll",
+ "build/netstandard2.0/ref/System.IO.Compression.ZipFile.dll",
+ "build/netstandard2.0/ref/System.IO.Compression.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.DriveInfo.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.Primitives.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.Watcher.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.dll",
+ "build/netstandard2.0/ref/System.IO.IsolatedStorage.dll",
+ "build/netstandard2.0/ref/System.IO.MemoryMappedFiles.dll",
+ "build/netstandard2.0/ref/System.IO.Pipes.dll",
+ "build/netstandard2.0/ref/System.IO.UnmanagedMemoryStream.dll",
+ "build/netstandard2.0/ref/System.IO.dll",
+ "build/netstandard2.0/ref/System.Linq.Expressions.dll",
+ "build/netstandard2.0/ref/System.Linq.Parallel.dll",
+ "build/netstandard2.0/ref/System.Linq.Queryable.dll",
+ "build/netstandard2.0/ref/System.Linq.dll",
+ "build/netstandard2.0/ref/System.Net.Http.dll",
+ "build/netstandard2.0/ref/System.Net.NameResolution.dll",
+ "build/netstandard2.0/ref/System.Net.NetworkInformation.dll",
+ "build/netstandard2.0/ref/System.Net.Ping.dll",
+ "build/netstandard2.0/ref/System.Net.Primitives.dll",
+ "build/netstandard2.0/ref/System.Net.Requests.dll",
+ "build/netstandard2.0/ref/System.Net.Security.dll",
+ "build/netstandard2.0/ref/System.Net.Sockets.dll",
+ "build/netstandard2.0/ref/System.Net.WebHeaderCollection.dll",
+ "build/netstandard2.0/ref/System.Net.WebSockets.Client.dll",
+ "build/netstandard2.0/ref/System.Net.WebSockets.dll",
+ "build/netstandard2.0/ref/System.Net.dll",
+ "build/netstandard2.0/ref/System.Numerics.dll",
+ "build/netstandard2.0/ref/System.ObjectModel.dll",
+ "build/netstandard2.0/ref/System.Reflection.Extensions.dll",
+ "build/netstandard2.0/ref/System.Reflection.Primitives.dll",
+ "build/netstandard2.0/ref/System.Reflection.dll",
+ "build/netstandard2.0/ref/System.Resources.Reader.dll",
+ "build/netstandard2.0/ref/System.Resources.ResourceManager.dll",
+ "build/netstandard2.0/ref/System.Resources.Writer.dll",
+ "build/netstandard2.0/ref/System.Runtime.CompilerServices.VisualC.dll",
+ "build/netstandard2.0/ref/System.Runtime.Extensions.dll",
+ "build/netstandard2.0/ref/System.Runtime.Handles.dll",
+ "build/netstandard2.0/ref/System.Runtime.InteropServices.RuntimeInformation.dll",
+ "build/netstandard2.0/ref/System.Runtime.InteropServices.dll",
+ "build/netstandard2.0/ref/System.Runtime.Numerics.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Formatters.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Json.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Primitives.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Xml.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.dll",
+ "build/netstandard2.0/ref/System.Runtime.dll",
+ "build/netstandard2.0/ref/System.Security.Claims.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Algorithms.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Csp.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Encoding.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Primitives.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.X509Certificates.dll",
+ "build/netstandard2.0/ref/System.Security.Principal.dll",
+ "build/netstandard2.0/ref/System.Security.SecureString.dll",
+ "build/netstandard2.0/ref/System.ServiceModel.Web.dll",
+ "build/netstandard2.0/ref/System.Text.Encoding.Extensions.dll",
+ "build/netstandard2.0/ref/System.Text.Encoding.dll",
+ "build/netstandard2.0/ref/System.Text.RegularExpressions.dll",
+ "build/netstandard2.0/ref/System.Threading.Overlapped.dll",
+ "build/netstandard2.0/ref/System.Threading.Tasks.Parallel.dll",
+ "build/netstandard2.0/ref/System.Threading.Tasks.dll",
+ "build/netstandard2.0/ref/System.Threading.Thread.dll",
+ "build/netstandard2.0/ref/System.Threading.ThreadPool.dll",
+ "build/netstandard2.0/ref/System.Threading.Timer.dll",
+ "build/netstandard2.0/ref/System.Threading.dll",
+ "build/netstandard2.0/ref/System.Transactions.dll",
+ "build/netstandard2.0/ref/System.ValueTuple.dll",
+ "build/netstandard2.0/ref/System.Web.dll",
+ "build/netstandard2.0/ref/System.Windows.dll",
+ "build/netstandard2.0/ref/System.Xml.Linq.dll",
+ "build/netstandard2.0/ref/System.Xml.ReaderWriter.dll",
+ "build/netstandard2.0/ref/System.Xml.Serialization.dll",
+ "build/netstandard2.0/ref/System.Xml.XDocument.dll",
+ "build/netstandard2.0/ref/System.Xml.XPath.XDocument.dll",
+ "build/netstandard2.0/ref/System.Xml.XPath.dll",
+ "build/netstandard2.0/ref/System.Xml.XmlDocument.dll",
+ "build/netstandard2.0/ref/System.Xml.XmlSerializer.dll",
+ "build/netstandard2.0/ref/System.Xml.dll",
+ "build/netstandard2.0/ref/System.dll",
+ "build/netstandard2.0/ref/mscorlib.dll",
+ "build/netstandard2.0/ref/netstandard.dll",
+ "build/netstandard2.0/ref/netstandard.xml",
+ "lib/netstandard1.0/_._",
+ "netstandard.library.2.0.3.nupkg.sha512",
+ "netstandard.library.nuspec"
+ ]
+ },
+ "xunit/2.4.2": {
+ "sha512": "6Mj73Ont3zj2CJuoykVJfE0ZmRwn7C+pTuRP8c4bnaaTFjwNG6tGe0prJ1yIbMe9AHrpDys63ctWacSsFJWK/w==",
+ "type": "package",
+ "path": "xunit/2.4.2",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "_content/logo-128-transparent.png",
+ "xunit.2.4.2.nupkg.sha512",
+ "xunit.nuspec"
+ ]
+ },
+ "xunit.abstractions/2.0.3": {
+ "sha512": "pot1I4YOxlWjIb5jmwvvQNbTrZ3lJQ+jUGkGjWE3hEFM0l5gOnBWS+H3qsex68s5cO52g+44vpGzhAt+42vwKg==",
+ "type": "package",
+ "path": "xunit.abstractions/2.0.3",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "lib/net35/xunit.abstractions.dll",
+ "lib/net35/xunit.abstractions.xml",
+ "lib/netstandard1.0/xunit.abstractions.dll",
+ "lib/netstandard1.0/xunit.abstractions.xml",
+ "lib/netstandard2.0/xunit.abstractions.dll",
+ "lib/netstandard2.0/xunit.abstractions.xml",
+ "xunit.abstractions.2.0.3.nupkg.sha512",
+ "xunit.abstractions.nuspec"
+ ]
+ },
+ "xunit.analyzers/1.0.0": {
+ "sha512": "BeO8hEgs/c8Ls2647fPfieMngncvf0D0xYNDfIO59MolxtCtVjFRd6SRc+7tj8VMqkVOuJcnc9eh4ngI2cAmLQ==",
+ "type": "package",
+ "path": "xunit.analyzers/1.0.0",
+ "hasTools": true,
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "_content/logo-128-transparent.png",
+ "analyzers/dotnet/cs/xunit.analyzers.dll",
+ "analyzers/dotnet/cs/xunit.analyzers.fixes.dll",
+ "tools/install.ps1",
+ "tools/uninstall.ps1",
+ "xunit.analyzers.1.0.0.nupkg.sha512",
+ "xunit.analyzers.nuspec"
+ ]
+ },
+ "xunit.assert/2.4.2": {
+ "sha512": "pxJISOFjn2XTTi1mcDCkRZrTFb9OtRRCtx2kZFNF51GdReLr1ls2rnyxvAS4JO247K3aNtflvh5Q0346K5BROA==",
+ "type": "package",
+ "path": "xunit.assert/2.4.2",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "_content/logo-128-transparent.png",
+ "lib/netstandard1.1/xunit.assert.dll",
+ "lib/netstandard1.1/xunit.assert.xml",
+ "xunit.assert.2.4.2.nupkg.sha512",
+ "xunit.assert.nuspec"
+ ]
+ },
+ "xunit.core/2.4.2": {
+ "sha512": "KB4yGCxNqIVyekhJLXtKSEq6BaXVp/JO3mbGVE1hxypZTLEe7h+sTbAhpA+yZW2dPtXTuiW+C1B2oxxHEkrmOw==",
+ "type": "package",
+ "path": "xunit.core/2.4.2",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "_content/logo-128-transparent.png",
+ "build/xunit.core.props",
+ "build/xunit.core.targets",
+ "buildMultiTargeting/xunit.core.props",
+ "buildMultiTargeting/xunit.core.targets",
+ "xunit.core.2.4.2.nupkg.sha512",
+ "xunit.core.nuspec"
+ ]
+ },
+ "xunit.extensibility.core/2.4.2": {
+ "sha512": "W1BoXTIN1C6kpVSMw25huSet25ky6IAQUNovu3zGOGN/jWnbgSoTyCrlIhmXSg0tH5nEf8q7h3OjNHOjyu5PfA==",
+ "type": "package",
+ "path": "xunit.extensibility.core/2.4.2",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "_content/logo-128-transparent.png",
+ "lib/net452/xunit.core.dll",
+ "lib/net452/xunit.core.dll.tdnet",
+ "lib/net452/xunit.core.xml",
+ "lib/net452/xunit.runner.tdnet.dll",
+ "lib/net452/xunit.runner.utility.net452.dll",
+ "lib/netstandard1.1/xunit.core.dll",
+ "lib/netstandard1.1/xunit.core.xml",
+ "xunit.extensibility.core.2.4.2.nupkg.sha512",
+ "xunit.extensibility.core.nuspec"
+ ]
+ },
+ "xunit.extensibility.execution/2.4.2": {
+ "sha512": "CZmgcKkwpyo8FlupZdWpJCryrAOWLh1FBPG6gmVZuPQkGQsim/oL4PcP4nfrC2hHgXUFtluvaJ0Sp9PQKUMNpg==",
+ "type": "package",
+ "path": "xunit.extensibility.execution/2.4.2",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "_content/logo-128-transparent.png",
+ "lib/net452/xunit.execution.desktop.dll",
+ "lib/net452/xunit.execution.desktop.xml",
+ "lib/netstandard1.1/xunit.execution.dotnet.dll",
+ "lib/netstandard1.1/xunit.execution.dotnet.xml",
+ "xunit.extensibility.execution.2.4.2.nupkg.sha512",
+ "xunit.extensibility.execution.nuspec"
+ ]
+ },
+ "FileRestitcher/1.0.0": {
+ "type": "project",
+ "path": "../FileRestitcher/FileRestitcher.csproj",
+ "msbuildProject": "../FileRestitcher/FileRestitcher.csproj"
+ }
+ },
+ "projectFileDependencyGroups": {
+ ".NETFramework,Version=v4.7.2": [
+ "FileRestitcher >= 1.0.0",
+ "Microsoft.NET.Test.Sdk >= 16.9.4",
+ "coverlet.collector >= 3.0.2",
+ "xunit >= 2.4.2"
+ ],
+ ".NETStandard,Version=v2.0": [
+ "FileRestitcher >= 1.0.0",
+ "Microsoft.NET.Test.Sdk >= 16.9.4",
+ "NETStandard.Library >= 2.0.3",
+ "coverlet.collector >= 3.0.2",
+ "xunit >= 2.4.2"
+ ]
+ },
+ "packageFolders": {
+ "C:\\Users\\Dimitri\\.nuget\\packages\\": {},
+ "C:\\Program Files (x86)\\Microsoft Visual Studio\\Shared\\NuGetPackages": {}
+ },
+ "project": {
+ "version": "1.0.0",
+ "restore": {
+ "projectUniqueName": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.csproj",
+ "projectName": "FileRestitcher.Tests",
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.csproj",
+ "packagesPath": "C:\\Users\\Dimitri\\.nuget\\packages\\",
+ "outputPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.NupkgProj\\",
+ "projectStyle": "PackageReference",
+ "crossTargeting": true,
+ "fallbackFolders": [
+ "C:\\Program Files (x86)\\Microsoft Visual Studio\\Shared\\NuGetPackages"
+ ],
+ "configFilePaths": [
+ "K:\\Proyects_Repos\\TorchSharp\\NuGet.Config",
+ "C:\\Users\\Dimitri\\AppData\\Roaming\\NuGet\\NuGet.Config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.FallbackLocation.config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.Offline.config"
+ ],
+ "originalTargetFrameworks": [
+ "net472",
+ "netstandard2.0"
+ ],
+ "sources": {
+ "C:\\Program Files (x86)\\Microsoft SDKs\\NuGetPackages\\": {},
+ "https://api.nuget.org/v3/index.json": {}
+ },
+ "frameworks": {
+ "net472": {
+ "targetAlias": "net472",
+ "projectReferences": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj": {
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj"
+ }
+ }
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "projectReferences": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj": {
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj"
+ }
+ }
+ }
+ },
+ "warningProperties": {
+ "warnAsError": [
+ "NU1605"
+ ]
+ },
+ "restoreAuditProperties": {
+ "enableAudit": "true",
+ "auditLevel": "low",
+ "auditMode": "all"
+ },
+ "SdkAnalysisLevel": "9.0.100"
+ },
+ "frameworks": {
+ "net472": {
+ "targetAlias": "net472",
+ "dependencies": {
+ "Microsoft.NET.Test.Sdk": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[16.9.4, )"
+ },
+ "coverlet.collector": {
+ "include": "Runtime, Build, Native, ContentFiles, Analyzers, BuildTransitive",
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[3.0.2, )"
+ },
+ "xunit": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[2.4.2, )"
+ }
+ },
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "dependencies": {
+ "Microsoft.NET.Test.Sdk": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[16.9.4, )"
+ },
+ "NETStandard.Library": {
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[2.0.3, )",
+ "autoReferenced": true
+ },
+ "coverlet.collector": {
+ "include": "Runtime, Build, Native, ContentFiles, Analyzers, BuildTransitive",
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[3.0.2, )"
+ },
+ "xunit": {
+ "suppressParent": "None",
+ "target": "Package",
+ "version": "[2.4.2, )"
+ }
+ },
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/project.nuget.cache b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/project.nuget.cache
new file mode 100644
index 000000000..fd9b0a74d
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.NupkgProj/project.nuget.cache
@@ -0,0 +1,21 @@
+{
+ "version": 2,
+ "dgSpecHash": "md8eUrGszbk=",
+ "success": true,
+ "projectFilePath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher.Tests\\FileRestitcher.Tests.csproj",
+ "expectedPackageFiles": [
+ "C:\\Users\\Dimitri\\.nuget\\packages\\coverlet.collector\\3.0.2\\coverlet.collector.3.0.2.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\microsoft.codecoverage\\16.9.4\\microsoft.codecoverage.16.9.4.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\microsoft.net.test.sdk\\16.9.4\\microsoft.net.test.sdk.16.9.4.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\microsoft.netcore.platforms\\1.1.0\\microsoft.netcore.platforms.1.1.0.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\netstandard.library\\2.0.3\\netstandard.library.2.0.3.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\xunit\\2.4.2\\xunit.2.4.2.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\xunit.abstractions\\2.0.3\\xunit.abstractions.2.0.3.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\xunit.analyzers\\1.0.0\\xunit.analyzers.1.0.0.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\xunit.assert\\2.4.2\\xunit.assert.2.4.2.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\xunit.core\\2.4.2\\xunit.core.2.4.2.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\xunit.extensibility.core\\2.4.2\\xunit.extensibility.core.2.4.2.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\xunit.extensibility.execution\\2.4.2\\xunit.extensibility.execution.2.4.2.nupkg.sha512"
+ ],
+ "logs": []
+}
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.csproj b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.csproj
index 03a104299..21006ad7d 100644
--- a/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.csproj
+++ b/pkg/FileRestitcher/FileRestitcher.Tests/FileRestitcher.Tests.csproj
@@ -1,9 +1,9 @@
-
+
false
-
+ netstandard2.0;$(TargetFrameworks)
net6.0
net472;$(TargetFrameworks)
@@ -13,8 +13,15 @@
+
+
+
-
+
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+ all
+
+
runtime; build; native; contentfiles; analyzers; buildtransitive
all
diff --git a/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.dgspec.json b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.dgspec.json
new file mode 100644
index 000000000..bbe687ab8
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.dgspec.json
@@ -0,0 +1,103 @@
+{
+ "format": 1,
+ "restore": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj": {}
+ },
+ "projects": {
+ "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj": {
+ "version": "1.0.0",
+ "restore": {
+ "projectUniqueName": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj",
+ "projectName": "FileRestitcher",
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj",
+ "packagesPath": "C:\\Users\\Dimitri\\.nuget\\packages\\",
+ "outputPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.NupkgProj\\",
+ "projectStyle": "PackageReference",
+ "crossTargeting": true,
+ "fallbackFolders": [
+ "C:\\Program Files (x86)\\Microsoft Visual Studio\\Shared\\NuGetPackages"
+ ],
+ "configFilePaths": [
+ "K:\\Proyects_Repos\\TorchSharp\\NuGet.Config",
+ "C:\\Users\\Dimitri\\AppData\\Roaming\\NuGet\\NuGet.Config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.FallbackLocation.config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.Offline.config"
+ ],
+ "originalTargetFrameworks": [
+ "net6.0",
+ "netstandard2.0"
+ ],
+ "sources": {
+ "C:\\Program Files (x86)\\Microsoft SDKs\\NuGetPackages\\": {},
+ "https://api.nuget.org/v3/index.json": {}
+ },
+ "frameworks": {
+ "net6.0": {
+ "targetAlias": "net6.0",
+ "projectReferences": {}
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "projectReferences": {}
+ }
+ },
+ "warningProperties": {
+ "warnAsError": [
+ "NU1605"
+ ]
+ },
+ "restoreAuditProperties": {
+ "enableAudit": "true",
+ "auditLevel": "low",
+ "auditMode": "all"
+ },
+ "SdkAnalysisLevel": "9.0.100"
+ },
+ "frameworks": {
+ "net6.0": {
+ "targetAlias": "net6.0",
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "frameworkReferences": {
+ "Microsoft.NETCore.App": {
+ "privateAssets": "all"
+ }
+ },
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "dependencies": {
+ "NETStandard.Library": {
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[2.0.3, )",
+ "autoReferenced": true
+ }
+ },
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.g.props b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.g.props
new file mode 100644
index 000000000..9c25bbe46
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.g.props
@@ -0,0 +1,16 @@
+
+
+
+ True
+ NuGet
+ $(MSBuildThisFileDirectory)project.assets.json
+ $(UserProfile)\.nuget\packages\
+ C:\Users\Dimitri\.nuget\packages\;C:\Program Files (x86)\Microsoft Visual Studio\Shared\NuGetPackages
+ PackageReference
+ 6.12.0
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.g.targets b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.g.targets
new file mode 100644
index 000000000..2192724bc
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/FileRestitcher.csproj.nuget.g.targets
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/project.assets.json b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/project.assets.json
new file mode 100644
index 000000000..7e747e944
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/project.assets.json
@@ -0,0 +1,283 @@
+{
+ "version": 3,
+ "targets": {
+ ".NETStandard,Version=v2.0": {
+ "Microsoft.NETCore.Platforms/1.1.0": {
+ "type": "package",
+ "compile": {
+ "lib/netstandard1.0/_._": {}
+ },
+ "runtime": {
+ "lib/netstandard1.0/_._": {}
+ }
+ },
+ "NETStandard.Library/2.0.3": {
+ "type": "package",
+ "dependencies": {
+ "Microsoft.NETCore.Platforms": "1.1.0"
+ },
+ "compile": {
+ "lib/netstandard1.0/_._": {}
+ },
+ "runtime": {
+ "lib/netstandard1.0/_._": {}
+ },
+ "build": {
+ "build/netstandard2.0/NETStandard.Library.targets": {}
+ }
+ }
+ },
+ "net6.0": {}
+ },
+ "libraries": {
+ "Microsoft.NETCore.Platforms/1.1.0": {
+ "sha512": "kz0PEW2lhqygehI/d6XsPCQzD7ff7gUJaVGPVETX611eadGsA3A877GdSlU0LRVMCTH/+P3o2iDTak+S08V2+A==",
+ "type": "package",
+ "path": "microsoft.netcore.platforms/1.1.0",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "ThirdPartyNotices.txt",
+ "dotnet_library_license.txt",
+ "lib/netstandard1.0/_._",
+ "microsoft.netcore.platforms.1.1.0.nupkg.sha512",
+ "microsoft.netcore.platforms.nuspec",
+ "runtime.json"
+ ]
+ },
+ "NETStandard.Library/2.0.3": {
+ "sha512": "st47PosZSHrjECdjeIzZQbzivYBJFv6P2nv4cj2ypdI204DO+vZ7l5raGMiX4eXMJ53RfOIg+/s4DHVZ54Nu2A==",
+ "type": "package",
+ "path": "netstandard.library/2.0.3",
+ "files": [
+ ".nupkg.metadata",
+ ".signature.p7s",
+ "LICENSE.TXT",
+ "THIRD-PARTY-NOTICES.TXT",
+ "build/netstandard2.0/NETStandard.Library.targets",
+ "build/netstandard2.0/ref/Microsoft.Win32.Primitives.dll",
+ "build/netstandard2.0/ref/System.AppContext.dll",
+ "build/netstandard2.0/ref/System.Collections.Concurrent.dll",
+ "build/netstandard2.0/ref/System.Collections.NonGeneric.dll",
+ "build/netstandard2.0/ref/System.Collections.Specialized.dll",
+ "build/netstandard2.0/ref/System.Collections.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.Composition.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.EventBasedAsync.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.Primitives.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.TypeConverter.dll",
+ "build/netstandard2.0/ref/System.ComponentModel.dll",
+ "build/netstandard2.0/ref/System.Console.dll",
+ "build/netstandard2.0/ref/System.Core.dll",
+ "build/netstandard2.0/ref/System.Data.Common.dll",
+ "build/netstandard2.0/ref/System.Data.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Contracts.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Debug.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.FileVersionInfo.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Process.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.StackTrace.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.TextWriterTraceListener.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Tools.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.TraceSource.dll",
+ "build/netstandard2.0/ref/System.Diagnostics.Tracing.dll",
+ "build/netstandard2.0/ref/System.Drawing.Primitives.dll",
+ "build/netstandard2.0/ref/System.Drawing.dll",
+ "build/netstandard2.0/ref/System.Dynamic.Runtime.dll",
+ "build/netstandard2.0/ref/System.Globalization.Calendars.dll",
+ "build/netstandard2.0/ref/System.Globalization.Extensions.dll",
+ "build/netstandard2.0/ref/System.Globalization.dll",
+ "build/netstandard2.0/ref/System.IO.Compression.FileSystem.dll",
+ "build/netstandard2.0/ref/System.IO.Compression.ZipFile.dll",
+ "build/netstandard2.0/ref/System.IO.Compression.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.DriveInfo.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.Primitives.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.Watcher.dll",
+ "build/netstandard2.0/ref/System.IO.FileSystem.dll",
+ "build/netstandard2.0/ref/System.IO.IsolatedStorage.dll",
+ "build/netstandard2.0/ref/System.IO.MemoryMappedFiles.dll",
+ "build/netstandard2.0/ref/System.IO.Pipes.dll",
+ "build/netstandard2.0/ref/System.IO.UnmanagedMemoryStream.dll",
+ "build/netstandard2.0/ref/System.IO.dll",
+ "build/netstandard2.0/ref/System.Linq.Expressions.dll",
+ "build/netstandard2.0/ref/System.Linq.Parallel.dll",
+ "build/netstandard2.0/ref/System.Linq.Queryable.dll",
+ "build/netstandard2.0/ref/System.Linq.dll",
+ "build/netstandard2.0/ref/System.Net.Http.dll",
+ "build/netstandard2.0/ref/System.Net.NameResolution.dll",
+ "build/netstandard2.0/ref/System.Net.NetworkInformation.dll",
+ "build/netstandard2.0/ref/System.Net.Ping.dll",
+ "build/netstandard2.0/ref/System.Net.Primitives.dll",
+ "build/netstandard2.0/ref/System.Net.Requests.dll",
+ "build/netstandard2.0/ref/System.Net.Security.dll",
+ "build/netstandard2.0/ref/System.Net.Sockets.dll",
+ "build/netstandard2.0/ref/System.Net.WebHeaderCollection.dll",
+ "build/netstandard2.0/ref/System.Net.WebSockets.Client.dll",
+ "build/netstandard2.0/ref/System.Net.WebSockets.dll",
+ "build/netstandard2.0/ref/System.Net.dll",
+ "build/netstandard2.0/ref/System.Numerics.dll",
+ "build/netstandard2.0/ref/System.ObjectModel.dll",
+ "build/netstandard2.0/ref/System.Reflection.Extensions.dll",
+ "build/netstandard2.0/ref/System.Reflection.Primitives.dll",
+ "build/netstandard2.0/ref/System.Reflection.dll",
+ "build/netstandard2.0/ref/System.Resources.Reader.dll",
+ "build/netstandard2.0/ref/System.Resources.ResourceManager.dll",
+ "build/netstandard2.0/ref/System.Resources.Writer.dll",
+ "build/netstandard2.0/ref/System.Runtime.CompilerServices.VisualC.dll",
+ "build/netstandard2.0/ref/System.Runtime.Extensions.dll",
+ "build/netstandard2.0/ref/System.Runtime.Handles.dll",
+ "build/netstandard2.0/ref/System.Runtime.InteropServices.RuntimeInformation.dll",
+ "build/netstandard2.0/ref/System.Runtime.InteropServices.dll",
+ "build/netstandard2.0/ref/System.Runtime.Numerics.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Formatters.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Json.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Primitives.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.Xml.dll",
+ "build/netstandard2.0/ref/System.Runtime.Serialization.dll",
+ "build/netstandard2.0/ref/System.Runtime.dll",
+ "build/netstandard2.0/ref/System.Security.Claims.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Algorithms.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Csp.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Encoding.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.Primitives.dll",
+ "build/netstandard2.0/ref/System.Security.Cryptography.X509Certificates.dll",
+ "build/netstandard2.0/ref/System.Security.Principal.dll",
+ "build/netstandard2.0/ref/System.Security.SecureString.dll",
+ "build/netstandard2.0/ref/System.ServiceModel.Web.dll",
+ "build/netstandard2.0/ref/System.Text.Encoding.Extensions.dll",
+ "build/netstandard2.0/ref/System.Text.Encoding.dll",
+ "build/netstandard2.0/ref/System.Text.RegularExpressions.dll",
+ "build/netstandard2.0/ref/System.Threading.Overlapped.dll",
+ "build/netstandard2.0/ref/System.Threading.Tasks.Parallel.dll",
+ "build/netstandard2.0/ref/System.Threading.Tasks.dll",
+ "build/netstandard2.0/ref/System.Threading.Thread.dll",
+ "build/netstandard2.0/ref/System.Threading.ThreadPool.dll",
+ "build/netstandard2.0/ref/System.Threading.Timer.dll",
+ "build/netstandard2.0/ref/System.Threading.dll",
+ "build/netstandard2.0/ref/System.Transactions.dll",
+ "build/netstandard2.0/ref/System.ValueTuple.dll",
+ "build/netstandard2.0/ref/System.Web.dll",
+ "build/netstandard2.0/ref/System.Windows.dll",
+ "build/netstandard2.0/ref/System.Xml.Linq.dll",
+ "build/netstandard2.0/ref/System.Xml.ReaderWriter.dll",
+ "build/netstandard2.0/ref/System.Xml.Serialization.dll",
+ "build/netstandard2.0/ref/System.Xml.XDocument.dll",
+ "build/netstandard2.0/ref/System.Xml.XPath.XDocument.dll",
+ "build/netstandard2.0/ref/System.Xml.XPath.dll",
+ "build/netstandard2.0/ref/System.Xml.XmlDocument.dll",
+ "build/netstandard2.0/ref/System.Xml.XmlSerializer.dll",
+ "build/netstandard2.0/ref/System.Xml.dll",
+ "build/netstandard2.0/ref/System.dll",
+ "build/netstandard2.0/ref/mscorlib.dll",
+ "build/netstandard2.0/ref/netstandard.dll",
+ "build/netstandard2.0/ref/netstandard.xml",
+ "lib/netstandard1.0/_._",
+ "netstandard.library.2.0.3.nupkg.sha512",
+ "netstandard.library.nuspec"
+ ]
+ }
+ },
+ "projectFileDependencyGroups": {
+ ".NETStandard,Version=v2.0": [
+ "NETStandard.Library >= 2.0.3"
+ ],
+ "net6.0": []
+ },
+ "packageFolders": {
+ "C:\\Users\\Dimitri\\.nuget\\packages\\": {},
+ "C:\\Program Files (x86)\\Microsoft Visual Studio\\Shared\\NuGetPackages": {}
+ },
+ "project": {
+ "version": "1.0.0",
+ "restore": {
+ "projectUniqueName": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj",
+ "projectName": "FileRestitcher",
+ "projectPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj",
+ "packagesPath": "C:\\Users\\Dimitri\\.nuget\\packages\\",
+ "outputPath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.NupkgProj\\",
+ "projectStyle": "PackageReference",
+ "crossTargeting": true,
+ "fallbackFolders": [
+ "C:\\Program Files (x86)\\Microsoft Visual Studio\\Shared\\NuGetPackages"
+ ],
+ "configFilePaths": [
+ "K:\\Proyects_Repos\\TorchSharp\\NuGet.Config",
+ "C:\\Users\\Dimitri\\AppData\\Roaming\\NuGet\\NuGet.Config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.FallbackLocation.config",
+ "C:\\Program Files (x86)\\NuGet\\Config\\Microsoft.VisualStudio.Offline.config"
+ ],
+ "originalTargetFrameworks": [
+ "net6.0",
+ "netstandard2.0"
+ ],
+ "sources": {
+ "C:\\Program Files (x86)\\Microsoft SDKs\\NuGetPackages\\": {},
+ "https://api.nuget.org/v3/index.json": {}
+ },
+ "frameworks": {
+ "net6.0": {
+ "targetAlias": "net6.0",
+ "projectReferences": {}
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "projectReferences": {}
+ }
+ },
+ "warningProperties": {
+ "warnAsError": [
+ "NU1605"
+ ]
+ },
+ "restoreAuditProperties": {
+ "enableAudit": "true",
+ "auditLevel": "low",
+ "auditMode": "all"
+ },
+ "SdkAnalysisLevel": "9.0.100"
+ },
+ "frameworks": {
+ "net6.0": {
+ "targetAlias": "net6.0",
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "frameworkReferences": {
+ "Microsoft.NETCore.App": {
+ "privateAssets": "all"
+ }
+ },
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ },
+ "netstandard2.0": {
+ "targetAlias": "netstandard2.0",
+ "dependencies": {
+ "NETStandard.Library": {
+ "suppressParent": "All",
+ "target": "Package",
+ "version": "[2.0.3, )",
+ "autoReferenced": true
+ }
+ },
+ "imports": [
+ "net461",
+ "net462",
+ "net47",
+ "net471",
+ "net472",
+ "net48",
+ "net481"
+ ],
+ "assetTargetFallback": true,
+ "warn": true,
+ "runtimeIdentifierGraphPath": "C:\\Program Files\\dotnet\\sdk\\9.0.100\\RuntimeIdentifierGraph.json"
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/project.nuget.cache b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/project.nuget.cache
new file mode 100644
index 000000000..aab7970d8
--- /dev/null
+++ b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.NupkgProj/project.nuget.cache
@@ -0,0 +1,11 @@
+{
+ "version": 2,
+ "dgSpecHash": "rM+0M7K4/ZA=",
+ "success": true,
+ "projectFilePath": "K:\\Proyects_Repos\\TorchSharp\\pkg\\FileRestitcher\\FileRestitcher\\FileRestitcher.csproj",
+ "expectedPackageFiles": [
+ "C:\\Users\\Dimitri\\.nuget\\packages\\microsoft.netcore.platforms\\1.1.0\\microsoft.netcore.platforms.1.1.0.nupkg.sha512",
+ "C:\\Users\\Dimitri\\.nuget\\packages\\netstandard.library\\2.0.3\\netstandard.library.2.0.3.nupkg.sha512"
+ ],
+ "logs": []
+}
\ No newline at end of file
diff --git a/pkg/FileRestitcher/FileRestitcher/FileRestitcher.csproj b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.csproj
index 3ab2bb061..68dd5b1d2 100644
--- a/pkg/FileRestitcher/FileRestitcher/FileRestitcher.csproj
+++ b/pkg/FileRestitcher/FileRestitcher/FileRestitcher.csproj
@@ -1,10 +1,10 @@
-
+
false
Library
- netstandard2.0
+ netstandard2.0;net6.0
false
-
+
diff --git a/pkg/pack.proj b/pkg/pack.proj
index 3c9db2f98..c05c5e610 100644
--- a/pkg/pack.proj
+++ b/pkg/pack.proj
@@ -1,6 +1,6 @@
-
+
diff --git a/src/Examples.Utils/Examples.Utils.csproj b/src/Examples.Utils/Examples.Utils.csproj
index f798b1389..6fa145333 100644
--- a/src/Examples.Utils/Examples.Utils.csproj
+++ b/src/Examples.Utils/Examples.Utils.csproj
@@ -4,6 +4,9 @@
9.0
+ net6.0
+ net472;$(TargetFrameworks);netstandard2.0
+
net6.0
@@ -17,7 +20,10 @@
-
+
+
+
+
diff --git a/src/Examples.Utils/Vocab.cs b/src/Examples.Utils/Vocab.cs
index 743e4c55c..7a1deb298 100644
--- a/src/Examples.Utils/Vocab.cs
+++ b/src/Examples.Utils/Vocab.cs
@@ -88,12 +88,17 @@ public void Add(KeyValuePair item)
{
Add(item.Key, item.Value);
}
-
+#if NETSTANDARD2_0
+ public bool TryGetValue(string key, out int value)
+ {
+ return _dict.TryGetValue(key, out value);
+ }
+#else
public bool TryGetValue(string key, [MaybeNullWhen(false)] out int value)
{
return _dict.TryGetValue(key, out value);
}
-
+#endif
private Dictionary _dict = new Dictionary();
private int _last = 0;
}
diff --git a/src/Examples/AdversarialExampleGeneration.cs b/src/Examples/AdversarialExampleGeneration.cs
index 7bfc174b2..49bd10956 100644
--- a/src/Examples/AdversarialExampleGeneration.cs
+++ b/src/Examples/AdversarialExampleGeneration.cs
@@ -34,6 +34,8 @@ public class AdversarialExampleGeneration
{
#if NET472_OR_GREATER
private readonly static string _dataLocation = NSPath.Join(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "mnist");
+#elif NETSTANDARD2_0
+ private readonly static string _dataLocation = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "mnist");
#else
private readonly static string _dataLocation = Path.Join(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "mnist");
#endif // NET472_OR_GREATER
diff --git a/src/Examples/Examples.csproj b/src/Examples/Examples.csproj
index 3cb0bed27..9b7a980b9 100644
--- a/src/Examples/Examples.csproj
+++ b/src/Examples/Examples.csproj
@@ -5,9 +5,12 @@
true
true
-
+
+ net472;netstandard2.0;$(TargetFrameworks)
9.0
- net6.0
+
+ net6.0
true
false
false
diff --git a/src/Examples/SequenceToSequence.cs b/src/Examples/SequenceToSequence.cs
index 436c05a67..8ff2c6dc5 100644
--- a/src/Examples/SequenceToSequence.cs
+++ b/src/Examples/SequenceToSequence.cs
@@ -6,6 +6,7 @@
using System.Diagnostics;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
+using System.Text.RegularExpressions;
namespace TorchSharp.Examples
{
@@ -26,6 +27,8 @@ public class SequenceToSequence
// This path assumes that you're running this on Windows.
#if NET472_OR_GREATER
private readonly static string _dataLocation = NSPath.Join(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "wikitext-2-v1");
+#elif NETSTANDARD2_0
+ private readonly static string _dataLocation = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "wikitext-2-v1");
#else
private readonly static string _dataLocation = Path.Join(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "wikitext-2-v1");
#endif // NET472_OR_GREATER
@@ -251,7 +254,11 @@ private void InitWeights()
public override Tensor forward(Tensor t, Tensor mask)
{
+#if !NETSTANDARD2_0
var src = pos_encoder.call(encoder.call(t) * MathF.Sqrt(ninputs));
+#else
+ var src = pos_encoder.call(encoder.call(t) * (float)Math.Sqrt(ninputs));
+#endif
var enc = transformer_encoder.call(src, mask);
return decoder.call(enc);
}
diff --git a/src/Examples/TextClassification.cs b/src/Examples/TextClassification.cs
index 8fb175718..4cdc79bc1 100644
--- a/src/Examples/TextClassification.cs
+++ b/src/Examples/TextClassification.cs
@@ -36,6 +36,8 @@ public class TextClassification
// This path assumes that you're running this on Windows.
#if NET472_OR_GREATER
private readonly static string _dataLocation = NSPath.Join(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "AG_NEWS");
+#elif NETSTANDARD2_0
+ private readonly static string _dataLocation = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "AG_NEWS");
#else
private readonly static string _dataLocation = Path.Join(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", "AG_NEWS");
#endif // NET472_OR_GREATER
diff --git a/src/FSharp.Examples/FSharp.Examples.fsproj b/src/FSharp.Examples/FSharp.Examples.fsproj
index 6259714c5..fe3c34a15 100644
--- a/src/FSharp.Examples/FSharp.Examples.fsproj
+++ b/src/FSharp.Examples/FSharp.Examples.fsproj
@@ -5,6 +5,8 @@
true
true
+ net6.0
+ net472;netstandard2.0;$(TargetFrameworks)
net6.0
true
Examples
diff --git a/src/Native/CMakeSettings.json b/src/Native/CMakeSettings.json
index 9204f06eb..11d28e957 100644
--- a/src/Native/CMakeSettings.json
+++ b/src/Native/CMakeSettings.json
@@ -1,4 +1,4 @@
-{
+{
"configurations": [
{
"name": "x64-Debug",
diff --git a/src/Native/LibTorchSharp/CMakeLists.txt b/src/Native/LibTorchSharp/CMakeLists.txt
index 60b61f049..560fba1a2 100644
--- a/src/Native/LibTorchSharp/CMakeLists.txt
+++ b/src/Native/LibTorchSharp/CMakeLists.txt
@@ -1,15 +1,38 @@
project(LibTorchSharp)
+find_package(CUDA)
+if(CUDA_FOUND)
+ include_directories(${CUDA_INCLUDE_DIRS})
+ link_directories(${CUDA_LIBRARY_DIRS})
+ add_compile_definitions(TORCHSHARP_CUDA_TOOLKIT_FOUND)
+endif()
+
+add_compile_definitions(NOMINMAX)
+
+
+#add_library(CUDA::nvToolsExt INTERFACE IMPORTED)
+# ensure that PyTorch is told to use NVTX3 headers
+#target_compile_definitions(CUDA::nvToolsExt INTERFACE TORCH_CUDA_USE_NVTX3)
+#target_link_libraries(CUDA::nvToolsExt INTERFACE CUDA::nvtx3)
+
+
+
if(APPLE AND NOT LIBTORCH_ARCH STREQUAL "arm64")
include_directories("/usr/local/include" "/usr/local/opt/llvm/include")
link_directories("/usr/local/lib" "/usr/local/opt/llvm/lib")
endif()
+
+#set(LIBTORCH_PATH "K:/FrameworksForC/LibTorch/libtorch-win-shared-with-deps-2.6.0+cu126")
find_package(Torch REQUIRED PATHS ${LIBTORCH_PATH})
+#find_package(Torch CONFIG)
set(SOURCES
cifar10.h
crc32c.h
+ THSAmp.h
THSAutograd.h
+ THSBFloat16.h
+ THSCuda.h
THSData.h
THSJIT.h
THSNN.h
@@ -21,8 +44,12 @@ set(SOURCES
cifar10.cpp
crc32c.c
THSActivation.cpp
+ THSAmp.cpp
THSAutograd.cpp
- THSData.cpp
+ THSBFloat16.cpp
+ THSCuda.cpp
+ THSConvolution.cpp
+ THSData.cpp
THSFFT.cpp
THSJIT.cpp
THSLinearAlgebra.cpp
@@ -70,6 +97,10 @@ include_directories(${TORCH_INCLUDE_DIRS})
add_library(LibTorchSharp SHARED ${SOURCES} ${RESOURCES})
+if(CUDA_FOUND)
+target_link_libraries(LibTorchSharp ${CUDA_LIBRARIES})
+endif()
+
target_link_libraries(LibTorchSharp ${TORCH_LIBRARIES})
set_property(TARGET LibTorchSharp PROPERTY CXX_STANDARD 14)
diff --git a/src/Native/LibTorchSharp/THSActivation.cpp b/src/Native/LibTorchSharp/THSActivation.cpp
index c89beaab6..966e5afc3 100644
--- a/src/Native/LibTorchSharp/THSActivation.cpp
+++ b/src/Native/LibTorchSharp/THSActivation.cpp
@@ -2,3 +2,331 @@
#include "THSNN.h"
#include
+
+NNModule THSNN_CELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::CELUOptions().alpha(alpha).inplace(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_CELU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_ELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::ELUOptions().alpha(alpha).inplace(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_ELU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_GELU_ctor(NNAnyModule* outAsAnyModule, const char* approximate)
+{
+ //res = create_module(outAsAnyModule);
+ CATCH_RETURN_NNModule(
+ res = create_module(torch::nn::GELUOptions().approximate(std::string(approximate)), outAsAnyModule);
+ );
+}
+
+Tensor THSNN_GELU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_GLU_ctor(const int64_t dim, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::GLUOptions().dim(dim);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_GLU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Hardshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::HardshrinkOptions(lambda);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Hardshrink_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Hardtanh_ctor(const double min_val, const double max_val, const bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::HardtanhOptions()
+ .min_val(min_val)
+ .max_val(max_val)
+ .inplace(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Hardtanh_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+
+NNModule THSNN_LeakyReLU_ctor(const double negative_sloope, const bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::LeakyReLUOptions().negative_slope(negative_sloope).inplace(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_LeakyReLU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_LogSoftmax_ctor(int64_t dim, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::LogSoftmaxOptions(dim);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_LogSoftmax_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Mish_ctor(NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ res = create_module(outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Mish_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_PReLU_ctor(const int64_t nparams, const double init, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::PReLUOptions().num_parameters(nparams).init(init);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_PReLU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+Tensor THSNN_PReLU_weight(const NNModule module)
+{
+ return get_weight(module);
+}
+
+void THSNN_PReLU_set_weight(const NNModule module, const Tensor weight)
+{
+ set_weight(module, weight);
+}
+
+NNModule THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::ReLUOptions(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_ReLU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_RReLU_ctor(const double lower, const double upper, const bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::RReLUOptions().lower(lower).upper(upper).inplace(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_RReLU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_ReLU6_ctor(bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::ReLU6Options(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_ReLU6_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_SELU_ctor(bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::SELUOptions(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_SELU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Sigmoid_ctor(NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ res = create_module(outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Sigmoid_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_SiLU_ctor(NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ res = create_module(outAsAnyModule);
+ );
+}
+
+Tensor THSNN_SiLU_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Softmax2d_ctor(NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ res = create_module(outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Softmax2d_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Softmax_ctor(const int64_t dim, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::SoftmaxOptions(dim);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Softmax_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Softmin_ctor(const int64_t dim, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::SoftminOptions(dim);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Softmin_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Softplus_ctor(const double beta, const double threshold, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::SoftplusOptions().beta(beta).threshold(threshold);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Softplus_forward(const NNModule module, const Tensor tensor) {
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Softshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::SoftshrinkOptions().lambda(lambda);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Softshrink_forward(const NNModule module, const Tensor tensor) {
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Softsign_ctor(NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ res = create_module(outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Softsign_forward(const NNModule module, const Tensor tensor) {
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Tanh_ctor(NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ res = create_module(outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Tanh_forward(const NNModule module, const Tensor tensor)
+{
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Tanhshrink_ctor(NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ res = create_module(outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Tanhshrink_forward(const NNModule module, const Tensor tensor) {
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
+NNModule THSNN_Threshold_ctor(const double threshold, const double value, const bool inplace, NNAnyModule* outAsAnyModule)
+{
+ CATCH_RETURN_NNModule(
+ auto opts = torch::nn::ThresholdOptions(threshold, value).inplace(inplace);
+ res = create_module(opts, outAsAnyModule);
+ );
+}
+
+Tensor THSNN_Threshold_forward(const NNModule module, const Tensor tensor) {
+ CATCH_TENSOR((*module)->as()->forward(*tensor));
+}
+
diff --git a/src/Native/LibTorchSharp/THSAmp.cpp b/src/Native/LibTorchSharp/THSAmp.cpp
new file mode 100644
index 000000000..79c6da9f2
--- /dev/null
+++ b/src/Native/LibTorchSharp/THSAmp.cpp
@@ -0,0 +1,89 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+#include "THSAmp.h"
+
+#include
+#include
+#include "torch/torch.h"
+#include "torch/cuda.h"
+
+/*void THSAmp_amp_foreach_non_finite_check_and_unscale_(const at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale)
+{
+ torch::_amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale);
+}*/
+
+void THSAmp_amp_foreach_non_finite_check_and_unscale_(Tensor* self, const int64_t tLength, at::Tensor& found_inf, const at::Tensor& inv_scale)
+{
+ torch::_amp_foreach_non_finite_check_and_unscale_(toTensors((torch::Tensor**)self, tLength),found_inf,inv_scale);
+}
+
+Tensor THSAmp_amp_update_scale_(at::Tensor& self, at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) {
+ CATCH_TENSOR(torch::_amp_update_scale_(self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval);)
+}
+Tensor THSAmp_amp_update_scale_out(at::Tensor& out, const at::Tensor& self, at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval){
+ CATCH_TENSOR(torch::_amp_update_scale_out(out, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval);)
+}
+Tensor THSAmp_amp_update_scale_outf(const at::Tensor& self, at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor& out){
+ CATCH_TENSOR(torch::_amp_update_scale_outf(self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval, out);)
+}
+
+Tensor THSAMP_amp_update_scale(const at::Tensor& self, const at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, Tensor* sec)
+{
+ std::tuple res;
+ CATCH(res = torch::_amp_update_scale(self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval);)
+ *sec = ResultTensor(std::get<1>(res));
+ return ResultTensor(std::get<0>(res));
+}
+
+bool THSAmp_is_torch_function_mode_enabled()
+{
+ return at::impl::torch_function_mode_enabled(); //https://github.com/pytorch/pytorch/blob/2c91e13afc6edcfe0a0e6189a88aae4ecbbf3516/torch/csrc/autograd/init.cpp#L911
+}
+
+bool THSAmp_is_autocast_cache_enabled()
+{
+ return at::autocast::is_autocast_cache_enabled();
+}
+
+bool THSAmp_is_autocast_available(int8_t device)
+{
+ return at::autocast::is_autocast_available((c10::DeviceType)device);
+}
+
+
+bool THSAmp_is_autocast_enabled(int8_t device)
+{
+ return at::autocast::is_autocast_enabled((at::DeviceType)device);
+}
+
+int8_t THSAmp_get_autocast_dtype(int8_t device)
+{
+ return (int8_t)at::autocast::get_autocast_dtype((at::DeviceType)device);
+}
+
+void THSAmp_set_autocast_dtype(int8_t device, int8_t dtype)
+{
+ at::autocast::set_autocast_dtype((at::DeviceType)device, (at::ScalarType)dtype);
+}
+
+void THSAmp_set_autocast_enabled(int8_t device, bool enabled)
+{
+ at::autocast::set_autocast_enabled((at::DeviceType)device, enabled);
+}
+int THSAmp_autocast_increment_nesting()
+{
+ return at::autocast::increment_nesting();
+}
+
+int THSAmp_autocast_decrement_nesting()
+{
+ return at::autocast::decrement_nesting();
+}
+
+void THSAmp_clear_autocast_cache()
+{
+ at::autocast::clear_cache();
+}
+void THSAmp_set_autocast_cache_enabled(bool enabled)
+{
+ at::autocast::set_autocast_cache_enabled(enabled);
+}
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSAmp.h b/src/Native/LibTorchSharp/THSAmp.h
new file mode 100644
index 000000000..4ae115dda
--- /dev/null
+++ b/src/Native/LibTorchSharp/THSAmp.h
@@ -0,0 +1,36 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+#pragma once
+
+#include "../Stdafx.h"
+#include "Utils.h"
+
+//https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py#L5957
+//EXPORT_API(void) THSAmp_amp_foreach_non_finite_check_and_unscale_(const at::TensorList self, at::Tensor& found_inf, const at::Tensor& inv_scale);
+
+EXPORT_API(void) THSAmp_amp_foreach_non_finite_check_and_unscale_(Tensor* self, const int64_t tLength, at::Tensor& found_inf, const at::Tensor& inv_scale);
+
+//EXPORT_API(void) THSAmp_amp_update_scale_(const at::Tensor& self, const at::Tensor& inv_scale);
+
+EXPORT_API(Tensor) THSAmp_amp_update_scale_(at::Tensor& self, at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval);
+EXPORT_API(Tensor) THSAmp_amp_update_scale_out(at::Tensor& out, const at::Tensor& self, at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval);
+EXPORT_API(Tensor) THSAmp_amp_update_scale_outf(const at::Tensor& self, at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor& out);
+EXPORT_API(Tensor) THSAMP_amp_update_scale(const at::Tensor& self, const at::Tensor& growth_tracker, const at::Tensor& found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, Tensor* sec);
+
+EXPORT_API(bool) THSAmp_is_torch_function_mode_enabled();
+
+EXPORT_API(bool) THSAmp_is_autocast_cache_enabled();
+
+EXPORT_API(bool) THSAmp_is_autocast_available(int8_t device);
+
+EXPORT_API(bool) THSAmp_is_autocast_enabled(int8_t device);
+EXPORT_API(int8_t) THSAmp_get_autocast_dtype(int8_t device);
+EXPORT_API(void) THSAmp_set_autocast_enabled(int8_t device, bool enabled);
+EXPORT_API(void) THSAmp_set_autocast_dtype(int8_t device, int8_t dtype);
+
+EXPORT_API(int) THSAmp_autocast_increment_nesting();
+EXPORT_API(int) THSAmp_autocast_decrement_nesting();
+
+EXPORT_API(void) THSAmp_set_autocast_cache_enabled(bool enabled);
+EXPORT_API(void) THSAmp_clear_autocast_cache();
+
+//EXPORT_API(bool) THSTorch_jit_is_scripting();
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSBFloat16.cpp b/src/Native/LibTorchSharp/THSBFloat16.cpp
new file mode 100644
index 000000000..9302eb565
--- /dev/null
+++ b/src/Native/LibTorchSharp/THSBFloat16.cpp
@@ -0,0 +1,101 @@
+#include "THSBFloat16.h"
+
+c10::BFloat16 bfloat16_ctor(float value)
+{
+ c10::BFloat16 bf16(value);
+ return bf16;
+}
+
+float op_float(c10::BFloat16 bf16)
+{
+ return static_cast(bf16);
+}
+
+c10::BFloat16 op_add(c10::BFloat16 a, c10::BFloat16 b){
+ return a + b;
+}
+c10::BFloat16 op_sub(c10::BFloat16 a, c10::BFloat16 b) {
+ return a - b;
+}
+c10::BFloat16 op_mul(c10::BFloat16 a, c10::BFloat16 b){
+ return a * b;
+}
+c10::BFloat16 op_div(c10::BFloat16 a, c10::BFloat16 b){
+ return a / b;
+}
+float op_add_float(c10::BFloat16 a, float b) {
+ return a + b;
+}
+float op_sub_float(c10::BFloat16 a, float b) {
+ return a - b;
+}
+float op_mul_float(c10::BFloat16 a, float b) {
+ return a * b;
+}
+float op_div_float(c10::BFloat16 a, float b) {
+ return a / b;
+}
+float op_add_lfloat(float a, c10::BFloat16 b) {
+ return a + b;
+}
+float op_sub_lfloat(float a, c10::BFloat16 b) {
+ return a - b;
+}
+float op_mul_lfloat(float a, c10::BFloat16 b) {
+ return a * b;
+}
+float op_div_lfloat(float a, c10::BFloat16 b) {
+ return a / b;
+}
+double op_add_double(c10::BFloat16 a, double b) {
+ return a + b;
+}
+double op_sub_double(c10::BFloat16 a, double b) {
+ return a - b;
+}
+double op_mul_double(c10::BFloat16 a, double b) {
+ return a * b;
+}
+double op_div_double(c10::BFloat16 a, double b) {
+ return a / b;
+}
+double op_add_ldouble(double a, c10::BFloat16 b) {
+ return a + b;
+}
+double op_sub_ldouble(double a, c10::BFloat16 b) {
+ return a - b;
+}
+double op_mul_ldouble(double a, c10::BFloat16 b) {
+ return a * b;
+}
+double op_div_ldouble(double a, c10::BFloat16 b) {
+ return a / b;
+}
+
+c10::BFloat16 bfloat16_min(c10::BFloat16 bf16) {
+ return std::numeric_limits::min();
+}
+c10::BFloat16 bfloat16_lowest(c10::BFloat16 bf16){
+ return std::numeric_limits::lowest();
+}
+c10::BFloat16 bfloat16_max(c10::BFloat16 bf16){
+ return std::numeric_limits::max();
+}
+c10::BFloat16 bfloat16_epsilon(c10::BFloat16 bf16){
+ return std::numeric_limits::epsilon();
+}
+c10::BFloat16 bfloat16_round_error(c10::BFloat16 bf16) {
+ return std::numeric_limits::round_error();
+}
+c10::BFloat16 bfloat16_infinity(c10::BFloat16 bf16) {
+ return std::numeric_limits::infinity();
+}
+c10::BFloat16 bfloat16_quiet_NaN(c10::BFloat16 bf16) {
+ return std::numeric_limits::quiet_NaN();
+}
+c10::BFloat16 bfloat16_signaling_NaN(c10::BFloat16 bf16) {
+ return std::numeric_limits::signaling_NaN();
+}
+c10::BFloat16 bfloat16_denorm_min(c10::BFloat16 bf16) {
+ return std::numeric_limits::denorm_min();
+}
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSBFloat16.h b/src/Native/LibTorchSharp/THSBFloat16.h
new file mode 100644
index 000000000..05305a472
--- /dev/null
+++ b/src/Native/LibTorchSharp/THSBFloat16.h
@@ -0,0 +1,43 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+#pragma once
+
+#include "../Stdafx.h"
+#include "Utils.h"
+
+#include "c10/util/BFloat16.h"
+//#include "c10/util/BFloat16-inl.h"
+
+EXPORT_API(c10::BFloat16) bfloat16_ctor(float value);
+EXPORT_API(float) op_float(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) op_add(c10::BFloat16 a, c10::BFloat16 b);
+EXPORT_API(c10::BFloat16) op_sub(c10::BFloat16 a, c10::BFloat16 b);
+EXPORT_API(c10::BFloat16) op_mul(c10::BFloat16 a, c10::BFloat16 b);
+EXPORT_API(c10::BFloat16) op_div(c10::BFloat16 a, c10::BFloat16 b);
+
+EXPORT_API(float) op_add_float(c10::BFloat16 a, float b);
+EXPORT_API(float) op_sub_float(c10::BFloat16 a, float b);
+EXPORT_API(float) op_mul_float(c10::BFloat16 a, float b);
+EXPORT_API(float) op_div_float(c10::BFloat16 a, float b);
+EXPORT_API(float) op_add_lfloat(float a, c10::BFloat16 b);
+EXPORT_API(float) op_sub_lfloat(float a, c10::BFloat16 b);
+EXPORT_API(float) op_mul_lfloat(float a, c10::BFloat16 b);
+EXPORT_API(float) op_div_lfloat(float a, c10::BFloat16 b);
+
+EXPORT_API(double) op_add_double(c10::BFloat16 a, double b);
+EXPORT_API(double) op_sub_double(c10::BFloat16 a, double b);
+EXPORT_API(double) op_mul_double(c10::BFloat16 a, double b);
+EXPORT_API(double) op_div_double(c10::BFloat16 a, double b);
+EXPORT_API(double) op_add_ldouble(double a, c10::BFloat16 b);
+EXPORT_API(double) op_sub_ldouble(double a, c10::BFloat16 b);
+EXPORT_API(double) op_mul_ldouble(double a, c10::BFloat16 b);
+EXPORT_API(double) op_div_ldouble(double a, c10::BFloat16 b);
+
+EXPORT_API(c10::BFloat16) bfloat16_min(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_lowest(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_max(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_epsilon(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_round_error(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_infinity(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_quiet_NaN(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_signaling_NaN(c10::BFloat16 bf16);
+EXPORT_API(c10::BFloat16) bfloat16_denorm_min(c10::BFloat16 bf16);
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSConvolution.cpp b/src/Native/LibTorchSharp/THSConvolution.cpp
index 621f8935c..3d8ca6aed 100644
--- a/src/Native/LibTorchSharp/THSConvolution.cpp
+++ b/src/Native/LibTorchSharp/THSConvolution.cpp
@@ -66,6 +66,7 @@ void THSNN_Conv1d_set_weight(const NNModule module, const Tensor weight)
set_weight(module, weight);
}
+
NNModule THSNN_Conv2d_ctor(const int64_t inputChannel, const int64_t outputChannel,
const int64_t kernelSize, const int64_t stride, const int64_t padding,
const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias,
@@ -140,6 +141,13 @@ void THSNN_Conv2d_set_weight(const NNModule module, const Tensor weight)
set_weight(module, weight);
}
+/*void THSNN_Conv2d_print_options(const NNModule module) {
+ auto opt = (*module)->as()->options;
+ ::std::cout << "Conv2d (" << std::to_string(opt.in_channels()) << "," << std::to_string(opt.out_channels()) << ")" << std::endl;
+}*/
+
+
+
NNModule THSNN_Conv3d_ctor(const int64_t inputChannel, const int64_t outputChannel,
const int64_t kernelSize, const int64_t stride, const int64_t padding,
const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias,
diff --git a/src/Native/LibTorchSharp/THSCuda.cpp b/src/Native/LibTorchSharp/THSCuda.cpp
new file mode 100644
index 000000000..baca29615
--- /dev/null
+++ b/src/Native/LibTorchSharp/THSCuda.cpp
@@ -0,0 +1,80 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+#include "THSCuda.h"
+
+#include
+#include
+
+#ifdef CUDA_TOOLKIT_FOUND
+cudaDeviceProp THSCuda_get_device_prop(int device)
+{
+ cudaDeviceProp cdp;
+ //cudaGetDeviceProperties(&cdp, device);
+ cudaGetDeviceProperties_v2(&cdp, device);
+ return cdp;
+}
+#endif
+
+int THSCuda_get_major_compute_capability(int device)
+{
+#ifdef CUDA_TOOLKIT_FOUND
+ return THSCuda_get_device_prop(device).major;
+#else
+ return -1;
+#endif
+}
+
+int THSCuda_get_minor_compute_capability(int device)
+{
+#ifdef CUDA_TOOLKIT_FOUND
+ return THSCuda_get_device_prop(device).minor;
+#else
+ return -1;
+#endif
+}
+
+
+int THSCuda_get_device_count(int* count)
+{
+#ifdef CUDA_TOOLKIT_FOUND
+ return cudaGetDeviceCount(count);
+#else
+ return -1;
+#endif
+}
+
+int THSCuda_get_free_total(int device, int* id, size_t* free, size_t* total)
+{
+#ifdef CUDA_TOOLKIT_FOUND
+ cudaError_t res = cudaSetDevice(device);
+ if (res != CUDA_SUCCESS)
+ return -1;
+ res = cudaGetDevice(id);
+ if (res != CUDA_SUCCESS)
+ return -1;
+ return cudaMemGetInfo(free, total);
+#else
+ return -1;
+#endif
+}
+
+size_t THSCuda_get_total_memory(int device)
+{
+#ifdef CUDA_TOOLKIT_FOUND
+ return THSCuda_get_device_prop(device).totalConstMem;
+#else
+ return 0; // size_t is unsigned, so it cannot be negative; 0 signals "unavailable".
+#endif
+ // NOTE(review): returns totalConstMem (constant memory), not total device memory — confirm intent vs. THSCuda_get_global_total_memory.
+}
+
+
+size_t THSCuda_get_global_total_memory(int device)
+{
+#ifdef CUDA_TOOLKIT_FOUND
+ return THSCuda_get_device_prop(device).totalGlobalMem;
+#else
+ return 0;
+#endif
+}
+
+//TODO: implement more functions
diff --git a/src/Native/LibTorchSharp/THSCuda.h b/src/Native/LibTorchSharp/THSCuda.h
new file mode 100644
index 000000000..00f1d7d03
--- /dev/null
+++ b/src/Native/LibTorchSharp/THSCuda.h
@@ -0,0 +1,48 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+#pragma once
+
+#include "../Stdafx.h"
+#include "Utils.h"
+#include "torch/torch.h"
+
+#ifdef TORCHSHARP_CUDA_TOOLKIT_FOUND
+//#undef CUDA_TOOLKIT_FOUND
+#define CUDA_TOOLKIT_FOUND 1
+#else
+#undef CUDA_TOOLKIT_FOUND
+#endif
+
+/*#define RETURN_CUDA_DEVICE(x) \
+ if(CUDA_TOOLKIT_FOUND) \
+ return x; \
+ else \
+ return -1; */
+
+#ifdef CUDA_TOOLKIT_FOUND
+#include "cuda.h"
+#include "cuda_runtime_api.h"
+
+cudaDeviceProp THSCuda_get_device_prop(int device=0);
+
+inline int show_available_memory()
+{
+ int num_gpus;
+ size_t free, total;
+ cudaGetDeviceCount(&num_gpus);
+ for (int gpu_id = 0; gpu_id < num_gpus; gpu_id++) {
+ cudaSetDevice(gpu_id);
+ int id;
+ cudaGetDevice(&id);
+ cudaMemGetInfo(&free, &total);
+ std::cout << "GPU " << id << " memory: free=" << free << ", total=" << total << std::endl;
+ }
+ return 0;
+}
+#endif
+
+EXPORT_API(int) THSCuda_get_major_compute_capability(int device);
+EXPORT_API(int) THSCuda_get_minor_compute_capability(int device);
+EXPORT_API(int) THSCuda_get_device_count(int* count);
+EXPORT_API(int) THSCuda_get_free_total(int device, int* id, size_t* free, size_t* total);
+EXPORT_API(size_t) THSCuda_get_total_memory(int device);
+EXPORT_API(size_t) THSCuda_get_global_total_memory(int device);
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp
index 4ed6419db..ea0ab8e8e 100644
--- a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp
+++ b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp
@@ -4,9 +4,15 @@
#include
#include
+#define IS_260_OR_NEWER TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6
+
Tensor THSLinalg_cholesky(const Tensor tensor)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_cholesky(*tensor))
+#else
CATCH_TENSOR(torch::linalg::cholesky(*tensor))
+#endif
}
Tensor THSLinalg_cholesky_ex(const Tensor tensor, bool check_errors, Tensor* info)
@@ -29,7 +35,11 @@ Tensor THSLinalg_cond_float(const Tensor tensor, const double p)
Tensor THSLinalg_cond_str(const Tensor tensor, const char* p)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(p != nullptr ? torch::linalg_cond(*tensor, c10::string_view(p)) : torch::linalg_cond(*tensor))
+#else
CATCH_TENSOR(p != nullptr ? torch::linalg_cond(*tensor, p) : torch::linalg_cond(*tensor))
+#endif
}
Tensor THSLinalg_cond_none(const Tensor tensor)
@@ -44,7 +54,11 @@ Tensor THSLinalg_cross(const Tensor input, const Tensor other, const int64_t dim
Tensor THSLinalg_det(const Tensor tensor)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_det(*tensor))
+#else
CATCH_TENSOR(torch::linalg::det(*tensor))
+#endif
}
Tensor THSTensor_logdet(const Tensor tensor)
@@ -55,7 +69,11 @@ Tensor THSTensor_logdet(const Tensor tensor)
Tensor THSLinalg_slogdet(const Tensor tensor, Tensor* logabsdet)
{
std::tuple res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_slogdet(*tensor);)
+#else
CATCH(res = torch::linalg::slogdet(*tensor);)
+#endif
*logabsdet = ResultTensor(std::get<1>(res));
return ResultTensor(std::get<0>(res));
}
@@ -63,7 +81,11 @@ Tensor THSLinalg_slogdet(const Tensor tensor, Tensor* logabsdet)
Tensor THSLinalg_eig(const Tensor tensor, Tensor* eigenvectors)
{
std::tuple res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_eig(*tensor);)
+#else
CATCH(res = torch::linalg::eig(*tensor););
+#endif
*eigenvectors = ResultTensor(std::get<1>(res));
return ResultTensor(std::get<0>(res));
}
@@ -93,31 +115,51 @@ Tensor THSLinalg_eigh(const Tensor tensor, const char UPLO, Tensor* eigenvectors
std::string _uplo;
_uplo.push_back(UPLO);
std::tuple res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_eigh(*tensor, _uplo););
+#else
CATCH(res = torch::linalg::eigh(*tensor, _uplo););
+#endif
*eigenvectors = ResultTensor(std::get<1>(res));
return ResultTensor(std::get<0>(res));
}
Tensor THSLinalg_eigvals(const Tensor tensor)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_eigvals(*tensor))
+#else
CATCH_TENSOR(torch::linalg::eigvals(*tensor))
+#endif
}
Tensor THSLinalg_eigvalsh(const Tensor tensor, const char UPLO)
{
std::string _uplo;
_uplo.push_back(UPLO);
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_eigvalsh(*tensor, _uplo))
+#else
CATCH_TENSOR(torch::linalg::eigvalsh(*tensor, _uplo))
+#endif
}
Tensor THSLinalg_householder_product(const Tensor tensor, const Tensor tau)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_householder_product(*tensor, *tau))
+#else
CATCH_TENSOR(torch::linalg::householder_product(*tensor, *tau))
+#endif
}
Tensor THSLinalg_inv(const Tensor tensor)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_inv(*tensor))
+#else
CATCH_TENSOR(torch::linalg::inv(*tensor))
+#endif
}
Tensor THSLinalg_inv_ex(const Tensor tensor, bool check_errors, Tensor* info)
@@ -131,7 +173,11 @@ Tensor THSLinalg_inv_ex(const Tensor tensor, bool check_errors, Tensor* info)
Tensor THSLinalg_lstsq_none(const Tensor A, const Tensor B, Tensor* residuals, Tensor* rank, Tensor* singular_values)
{
std::tuple res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_lstsq(*A, *B, c10::nullopt, c10::nullopt);)
+#else
CATCH(res = torch::linalg::lstsq(*A, *B, c10::nullopt, c10::nullopt);)
+#endif
*residuals = ResultTensor(std::get<1>(res));
*rank = ResultTensor(std::get<2>(res));
*singular_values = ResultTensor(std::get<3>(res));
@@ -141,7 +187,11 @@ Tensor THSLinalg_lstsq_none(const Tensor A, const Tensor B, Tensor* residuals, T
Tensor THSLinalg_lstsq_rcond(const Tensor A, const Tensor B, const double rcond, Tensor* residuals, Tensor* rank, Tensor* singular_values)
{
std::tuple res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_lstsq(*A, *B, rcond, c10::nullopt);)
+#else
CATCH(res = torch::linalg::lstsq(*A, *B, rcond, c10::nullopt);)
+#endif
*residuals = ResultTensor(std::get<1>(res));
*rank = ResultTensor(std::get<2>(res));
*singular_values = ResultTensor(std::get<3>(res));
@@ -151,7 +201,11 @@ Tensor THSLinalg_lstsq_rcond(const Tensor A, const Tensor B, const double rcond,
Tensor THSLinalg_lu(const Tensor A, const bool pivot, Tensor* L, Tensor* U)
{
std::tuple res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_lu(*A, pivot);)
+#else
CATCH(res = torch::linalg::lu(*A, pivot);)
+#endif
*L = ResultTensor(std::get<1>(res));
*U = ResultTensor(std::get<2>(res));
return ResultTensor(std::get<0>(res));
@@ -160,7 +214,12 @@ Tensor THSLinalg_lu(const Tensor A, const bool pivot, Tensor* L, Tensor* U)
Tensor THSLinalg_lu_factor(const Tensor A, const bool pivot, Tensor* pivots)
{
    std::tuple<torch::Tensor, torch::Tensor> res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_lu_factor(*A, pivot);)
+#else
CATCH(res = torch::linalg::lu_factor(*A, pivot);)
+#endif
+
*pivots = ResultTensor(std::get<1>(res));
return ResultTensor(std::get<0>(res));
}
@@ -190,69 +249,111 @@ Tensor THSLinalg_ldl_solve(const Tensor LD, const Tensor pivots, const Tensor B,
Tensor THSLinalg_matrix_norm(const Tensor tensor, const Scalar ord, const int64_t* dim, const int dim_length, const bool keepdim)
{
auto dims = c10::ArrayRef(dim, dim_length);
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_matrix_norm(*tensor, *ord, dims, keepdim, c10::nullopt))
+#else
CATCH_TENSOR(torch::linalg::matrix_norm(*tensor, *ord, dims, keepdim, c10::nullopt))
+#endif
}
Tensor THSLinalg_matrix_norm_fronuc(const Tensor tensor, const int8_t fronuc, const int64_t* dim, const int dim_length, const bool keepdim)
{
auto dims = c10::ArrayRef(dim, dim_length);
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_matrix_norm(*tensor, (fronuc == 0) ? "fro" : "nuc", dims, keepdim, c10::nullopt))
+#else
CATCH_TENSOR(torch::linalg::matrix_norm(*tensor, (fronuc == 0) ? "fro" : "nuc", dims, keepdim, c10::nullopt))
+#endif
}
Tensor THSLinalg_vector_norm(const Tensor tensor, const Scalar ord, const int64_t* dim, const int dim_length, const bool keepdim)
{
auto dims = c10::ArrayRef(dim, dim_length);
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_vector_norm(*tensor, *ord, dims, keepdim, c10::nullopt))
+#else
CATCH_TENSOR(torch::linalg::vector_norm(*tensor, *ord, dims, keepdim, c10::nullopt))
+#endif
}
Tensor THSLinalg_matrix_rank(const Tensor tensor, const double atol, const bool has_atol, const double rtol, const bool has_rtol, const bool hermitian)
{
    auto atol_ = has_atol ? atol : c10::optional<double>();
    auto rtol_ = has_rtol ? rtol : c10::optional<double>();
-
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_matrix_rank(*tensor, atol_, rtol_, hermitian))
+#else
CATCH_TENSOR(torch::linalg::matrix_rank(*tensor, atol_, rtol_, hermitian))
+#endif
}
Tensor THSLinalg_matrix_rank_tensor(const Tensor tensor, const Tensor atol, const Tensor rtol, const bool hermitian)
{
    const c10::optional<at::Tensor> atol_ = atol != nullptr ? *atol : c10::optional<at::Tensor>();
    const c10::optional<at::Tensor> rtol_ = rtol != nullptr ? *rtol : c10::optional<at::Tensor>();
-
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_matrix_rank(*tensor, atol_, rtol_, hermitian))
+#else
CATCH_TENSOR(torch::linalg::matrix_rank(*tensor, atol_, rtol_, hermitian))
+#endif
}
Tensor THSLinalg_matrix_power(const Tensor tensor, const int64_t n)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_matrix_power(*tensor, n))
+#else
CATCH_TENSOR(torch::linalg::matrix_power(*tensor, n))
+#endif
}
Tensor THSLinalg_multi_dot(const Tensor* tensors, const int length)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_multi_dot(toTensors((torch::Tensor**)tensors, length)))
+#else
CATCH_TENSOR(torch::linalg::multi_dot(toTensors((torch::Tensor**)tensors, length)))
+#endif
}
Tensor THSLinalg_norm_str(const Tensor tensor, const char* p, const int64_t* dim, const int dim_length, const bool keepdim)
{
    c10::optional<at::IntArrayRef> dims = (dim == nullptr) ? c10::nullopt : c10::optional<at::IntArrayRef>(at::ArrayRef<int64_t>(dim, dim_length));
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_norm(*tensor, c10::string_view(p), dims, keepdim, c10::nullopt))
+#else
CATCH_TENSOR(torch::linalg::norm(*tensor, p, dims, keepdim, c10::nullopt))
+#endif
}
Tensor THSLinalg_norm_float(const Tensor tensor, const double p, const int64_t* dim, const int dim_length, const bool keepdim)
{
    c10::optional<at::IntArrayRef> dims = (dim == nullptr) ? c10::nullopt : c10::optional<at::IntArrayRef>(at::ArrayRef<int64_t>(dim, dim_length));
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_norm(*tensor, p, dims, keepdim, c10::nullopt))
+#else
CATCH_TENSOR(torch::linalg::norm(*tensor, p, dims, keepdim, c10::nullopt))
+#endif
}
Tensor THSLinalg_norm_int(const Tensor tensor, const int p, const int64_t* dim, const int dim_length, const bool keepdim)
{
    c10::optional<at::IntArrayRef> dims = (dim == nullptr) ? c10::nullopt : c10::optional<at::IntArrayRef>(at::ArrayRef<int64_t>(dim, dim_length));
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_norm(*tensor, p, dims, keepdim, c10::nullopt))
+#else
CATCH_TENSOR(torch::linalg::norm(*tensor, p, dims, keepdim, c10::nullopt))
+#endif
}
Tensor THSLinalg_norm_opt(const Tensor tensor, const int64_t* dim, const int dim_length, const bool keepdim)
{
    c10::optional<at::IntArrayRef> dims = (dim == nullptr) ? c10::nullopt : c10::optional<at::IntArrayRef>(at::ArrayRef<int64_t>(dim, dim_length));
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_norm(*tensor, c10::nullopt, dims, keepdim, c10::nullopt))
+#else
CATCH_TENSOR(torch::linalg::norm(*tensor, c10::nullopt, dims, keepdim, c10::nullopt))
+#endif
}
Tensor THSLinalg_pinv(const Tensor tensor, const double atol, const bool has_atol, const double rtol, const bool has_rtol, const bool hermitian)
@@ -273,7 +374,11 @@ Tensor THSLinalg_pinv_tensor(const Tensor tensor, const Tensor atol, const Tenso
Tensor THSLinalg_pinverse(const Tensor tensor, const double rcond, const bool hermitian)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_pinv(*tensor, rcond, hermitian))
+#else
CATCH_TENSOR(torch::linalg::pinv(*tensor, rcond, hermitian))
+#endif
}
Tensor THSLinalg_qr(const Tensor tensor, const char mode, Tensor* R)
@@ -295,31 +400,52 @@ Tensor THSLinalg_qr(const Tensor tensor, const char mode, Tensor* R)
Tensor THSLinalg_solve(const Tensor tensor, Tensor other, bool left)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_solve(*tensor, *other, left))
+#else
CATCH_TENSOR(torch::linalg::solve(*tensor, *other, left))
+#endif
+
}
Tensor THSLinalg_solve_ex(const Tensor tensor, Tensor other, bool left, bool check_errors, Tensor* S)
{
    std::tuple<torch::Tensor, torch::Tensor> res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_solve_ex(*tensor, *other, left, check_errors););
+#else
CATCH(res = torch::linalg::solve_ex(*tensor, *other, left, check_errors););
+#endif
*S = ResultTensor(std::get<1>(res));
return ResultTensor(std::get<0>(res));
}
Tensor THSLinalg_solve_triangular(const Tensor tensor, Tensor other, bool upper, bool left, bool unitriangular)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_solve_triangular(*tensor, *other, upper, left, unitriangular))
+#else
CATCH_TENSOR(torch::linalg::solve_triangular(*tensor, *other, upper, left, unitriangular))
+#endif
}
Tensor THSLinalg_solve_triangular_out(const Tensor tensor, Tensor other, bool upper, bool left, bool unitriangular, Tensor result)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_solve_triangular_out(*result, *tensor, *other, upper, left, unitriangular))
+#else
CATCH_TENSOR(torch::linalg::solve_triangular_out(*result, *tensor, *other, upper, left, unitriangular))
+#endif
}
Tensor THSLinalg_svd(const Tensor tensor, const bool full_matrices, Tensor* S, Tensor* Vh)
{
    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> res;
+#if IS_260_OR_NEWER
+ CATCH(res = torch::linalg_svd(*tensor, full_matrices, c10::nullopt););
+#else
CATCH(res = torch::linalg::svd(*tensor, full_matrices, c10::nullopt););
+#endif
*S = ResultTensor(std::get<1>(res));
*Vh = ResultTensor(std::get<2>(res));
return ResultTensor(std::get<0>(res));
@@ -327,18 +453,30 @@ Tensor THSLinalg_svd(const Tensor tensor, const bool full_matrices, Tensor* S, T
Tensor THSLinalg_svdvals(const Tensor tensor)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(res = torch::linalg_svdvals(*tensor, c10::nullopt))
+#else
CATCH_TENSOR(res = torch::linalg::svdvals(*tensor, c10::nullopt))
+#endif
}
Tensor THSLinalg_tensorinv(const Tensor tensor, const int64_t ind)
{
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_tensorinv(*tensor, ind))
+#else
CATCH_TENSOR(torch::linalg::tensorinv(*tensor, ind))
+#endif
}
Tensor THSLinalg_tensorsolve(const Tensor tensor, Tensor other, const int64_t* dim, const int dim_length)
{
    c10::optional<at::IntArrayRef> dims = (dim == nullptr) ? c10::nullopt : c10::optional<at::IntArrayRef>(at::ArrayRef<int64_t>(dim, dim_length));
+#if IS_260_OR_NEWER
+ CATCH_TENSOR(torch::linalg_tensorsolve(*tensor, *other, dims))
+#else
CATCH_TENSOR(torch::linalg::tensorsolve(*tensor, *other, dims))
+#endif
}
Tensor THSLinalg_vander(const Tensor tensor, const int64_t N)
diff --git a/src/Native/LibTorchSharp/THSNN.cpp b/src/Native/LibTorchSharp/THSNN.cpp
index f5e9643e7..90794a012 100644
--- a/src/Native/LibTorchSharp/THSNN.cpp
+++ b/src/Native/LibTorchSharp/THSNN.cpp
@@ -1066,4 +1066,58 @@ Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key,
auto mask = attention_mask == nullptr ? c10::nullopt : c10::optional(*attention_mask);
CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, casual));
+}
+
+Tensor THSNN_normalize(Tensor input, float p, const int64_t* dim, float eps, Tensor out)
+{
+    auto opts = torch::nn::functional::NormalizeFuncOptions().p(p).eps(eps).dim(*dim); // NOTE(review): 'dim' is dereferenced without a null check and the 'out' parameter is never used — confirm with callers
+ CATCH_TENSOR(torch::nn::functional::normalize(*input, opts))
+ //CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, casual));
+}
+
+void THSNN_Print_Module(const NNModule module) {
+ std::ostringstream oss;
+ const std::string name = module->get()->name();
+ oss << name << "(";
+    if (auto* conv2 = (*module)->as<torch::nn::Conv2d>())
+ {
+ const auto opt = &conv2->options;
+ oss << opt->in_channels() << "," << opt->out_channels() << ", K=" << opt->kernel_size();
+ oss << ", S=" << opt->stride() << ", P=" << opt->padding().index() << ", D=" << opt->dilation();
+ oss << ", G=" << opt->groups() << ", B=" << opt->bias();
+ }
+    if (auto* bn2 = (*module)->as<torch::nn::BatchNorm2d>()) {
+ const auto opt = &bn2->options;
+ oss << opt->num_features() << ", Eps=" << opt->eps() << ", M=" << (opt->momentum().has_value() ? std::to_string(opt->momentum().value()) : "NaN");
+ oss << ", A=" << opt->affine() << ", T=" << opt->track_running_stats();
+ }
+    if(auto* ln = (*module)->as<torch::nn::LayerNorm>()) //Not printed because TorchSharp does not have a ctor for LayerNorm
+ {
+ const auto opt = ln->options;
+ oss << opt.eps() << ", Elem=" << opt.elementwise_affine() << ", N=[";
+        for(int64_t i=0;i< static_cast<int64_t>(opt.normalized_shape().size());i++)
+            oss << opt.normalized_shape()[i] << ((i == static_cast<int64_t>(opt.normalized_shape().size()-1)) ? "]" : ",");
+ }
+    if (const auto* d2 = (*module)->as<torch::nn::Dropout2d>()) //Not printed because TorchSharp does not have a ctor for Dropout2d
+ {
+ auto opt = d2->options;
+ oss << opt.p() << ", Inplace=" << opt.inplace();
+ }
+    if(auto* avp2 = (*module)->as<torch::nn::AdaptiveAvgPool2d>())
+ {
+ const auto opt = &avp2->options;
+ oss << "[";
+ for (int64_t i = 0; i < opt->output_size().size(); i++)
+ oss << opt->output_size()->at(i).value() << ((i == opt->output_size().size() - 1) ? "]" : ",");
+ }
+    if (auto* amp2 = (*module)->as<torch::nn::AdaptiveMaxPool2d>())
+ {
+        const auto opt = &amp2->options;
+ oss << "[";
+ for (int64_t i = 0; i < opt->output_size().size(); i++)
+ oss << opt->output_size()->at(i).value() << ((i == opt->output_size().size() - 1) ? "]" : ",");
+ }
+
+ oss << ")";
+ std::cout << oss.str() << std::endl;
}
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h
index f7af6bd1f..2bd59af29 100644
--- a/src/Native/LibTorchSharp/THSNN.h
+++ b/src/Native/LibTorchSharp/THSNN.h
@@ -37,6 +37,144 @@ EXPORT_API(void) THSNN_AnyModule_dispose(const NNAnyModule module);
EXPORT_API(NNModule) THSNN_custom_module(const char* name, Tensor(*forward)(Tensor), NNAnyModule* outAsAnyModule);
+// Pooling
+
+EXPORT_API(NNModule) THSNN_MaxPool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, const int64_t* dilation, bool ceil_mode, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_MaxPool1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_MaxPool1d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor *indices);
+
+EXPORT_API(NNModule) THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, bool ceil_mode, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_MaxPool2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_MaxPool2d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices);
+
+EXPORT_API(NNModule) THSNN_MaxPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, const int64_t* dilation, const int dilationLength, bool ceil_mode, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_MaxPool3d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_MaxPool3d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices);
+
+EXPORT_API(NNModule) THSNN_FractionalMaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_FractionalMaxPool2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_FractionalMaxPool2d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices);
+
+EXPORT_API(NNModule) THSNN_FractionalMaxPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* outputSize, const int outputSizeLength, const double* outputRatio, const int outputRatioLength, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_FractionalMaxPool3d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_FractionalMaxPool3d_forward_with_indices(const NNModule module, const Tensor tensor, Tensor* indices);
+
+EXPORT_API(NNModule) THSNN_MaxUnpool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_MaxUnpool1d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize);
+
+EXPORT_API(NNModule) THSNN_MaxUnpool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_MaxUnpool2d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize, const int outputSizeLength);
+
+EXPORT_API(NNModule) THSNN_MaxUnpool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_MaxUnpool3d_forward(const NNModule module, const Tensor tensor, const Tensor indices, const int64_t* outputSize, const int outputSizeLength);
+
+EXPORT_API(NNModule) THSNN_AdaptiveAvgPool1d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AdaptiveAvgPool1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_AdaptiveAvgPool2d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AdaptiveAvgPool2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_AdaptiveAvgPool3d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AdaptiveAvgPool3d_forward(const NNModule module, const Tensor tensor);
+
+EXPORT_API(NNModule) THSNN_AdaptiveMaxPool1d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AdaptiveMaxPool1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_AdaptiveMaxPool2d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AdaptiveMaxPool2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_AdaptiveMaxPool3d_ctor(const int64_t* sizes, const int length, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AdaptiveMaxPool3d_forward(const NNModule module, const Tensor tensor);
+
+EXPORT_API(NNModule) THSNN_AvgPool1d_ctor(const int64_t* kernelSize, const int64_t* stride, const int64_t* padding, bool ceil_mode, bool count_include_pad, int64_t divisor_override, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AvgPool1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, bool count_include_pad, int64_t divisor_override, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AvgPool2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_AvgPool3d_ctor(const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, const int64_t* padding, const int paddingLength, bool ceil_mode, bool count_include_pad, int64_t divisor_override, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_AvgPool3d_forward(const NNModule module, const Tensor tensor);
+
+EXPORT_API(NNModule) THSNN_LPPool1d_ctor(double norm_type, const int64_t* kernelSize, const int64_t* stride, bool ceil_mode, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_LPPool1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_LPPool2d_ctor(double norm_type, const int64_t* kernelSize, const int kernelSizeLength, const int64_t* stride, const int strideLength, bool ceil_mode, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_LPPool2d_forward(const NNModule module, const Tensor tensor);
+
+// Padding
+
+EXPORT_API(NNModule) THSNN_ZeroPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ZeroPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ZeroPad2d_forward(const NNModule module, const Tensor tensor);
+
+EXPORT_API(NNModule) THSNN_ConstantPad1d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ConstantPad1d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ConstantPad1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ConstantPad2d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ConstantPad2d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ConstantPad2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ConstantPad3d_ctor(const double value, const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ConstantPad3d_ctor_tuple(const double value, const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ConstantPad3d_forward(const NNModule module, const Tensor tensor);
+
+EXPORT_API(NNModule) THSNN_ReplicationPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ReplicationPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReplicationPad1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ReplicationPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ReplicationPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReplicationPad2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ReplicationPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ReplicationPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReplicationPad3d_forward(const NNModule module, const Tensor tensor);
+
+EXPORT_API(NNModule) THSNN_ReflectionPad1d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ReflectionPad1d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReflectionPad1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ReflectionPad2d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ReflectionPad2d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReflectionPad2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ReflectionPad3d_ctor(const int64_t padding, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ReflectionPad3d_ctor_tuple(const int64_t padding_left, const int64_t padding_right, const int64_t padding_top, const int64_t padding_bottom, const int64_t padding_front, const int64_t padding_back, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReflectionPad3d_forward(const NNModule module, const Tensor tensor);
+
+// Convolution
+
+EXPORT_API(NNModule) THSNN_Conv1d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Conv1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_Conv1d_bias(const NNModule module);
+EXPORT_API(void) THSNN_Conv1d_set_bias(const NNModule module, const Tensor bias);
+EXPORT_API(Tensor) THSNN_Conv1d_weight(const NNModule module);
+EXPORT_API(void) THSNN_Conv1d_set_weight(const NNModule module, const Tensor weight);
+EXPORT_API(NNModule) THSNN_Conv2d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_Conv2d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t strideX, const int64_t strideY, const int64_t paddingX, const int64_t paddingY, const int64_t dilationX, const int64_t dilationY, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Conv2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_Conv2d_weight(const NNModule module);
+EXPORT_API(void) THSNN_Conv2d_set_weight(const NNModule module, const Tensor weight);
+EXPORT_API(Tensor) THSNN_Conv2d_bias(const NNModule module);
+EXPORT_API(void) THSNN_Conv2d_set_bias(const NNModule module, const Tensor bias);
+//EXPORT_API(void) THSNN_Conv2d_print_options(const NNModule module);
+EXPORT_API(NNModule) THSNN_Conv3d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_Conv3d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t kernelZ, const int64_t strideX, const int64_t strideY, const int64_t strideZ, const int64_t paddingX, const int64_t paddingY, const int64_t paddingZ, const int64_t dilationX, const int64_t dilationY, const int64_t dilationZ, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Conv3d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_Conv3d_weight(const NNModule module);
+EXPORT_API(void) THSNN_Conv3d_set_weight(const NNModule module, const Tensor weight);
+EXPORT_API(Tensor) THSNN_Conv3d_bias(const NNModule module);
+EXPORT_API(void) THSNN_Conv3d_set_bias(const NNModule module, const Tensor bias);
+
+EXPORT_API(NNModule) THSNN_ConvTranspose1d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t output_padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ConvTranspose1d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_ConvTranspose1d_bias(const NNModule module);
+EXPORT_API(void) THSNN_ConvTranspose1d_set_bias(const NNModule module, const Tensor bias);
+EXPORT_API(Tensor) THSNN_ConvTranspose1d_weight(const NNModule module);
+EXPORT_API(void) THSNN_ConvTranspose1d_set_weight(const NNModule module, const Tensor weight);
+EXPORT_API(NNModule) THSNN_ConvTranspose2d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t output_padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ConvTranspose2d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t strideX, const int64_t strideY, const int64_t paddingX, const int64_t paddingY, const int64_t output_paddingX, const int64_t output_paddingY, const int64_t dilationX, const int64_t dilationY, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ConvTranspose2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_ConvTranspose2d_weight(const NNModule module);
+EXPORT_API(void) THSNN_ConvTranspose2d_set_weight(const NNModule module, const Tensor weight);
+EXPORT_API(Tensor) THSNN_ConvTranspose2d_bias(const NNModule module);
+EXPORT_API(void) THSNN_ConvTranspose2d_set_bias(const NNModule module, const Tensor bias);
+EXPORT_API(NNModule) THSNN_ConvTranspose3d_ctor(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelSize, const int64_t stride, const int64_t padding, const int64_t output_padding, const int64_t dilation, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(NNModule) THSNN_ConvTranspose3d_ctor_1(const int64_t inputChannel, const int64_t outputChannel, const int64_t kernelX, const int64_t kernelY, const int64_t kernelZ, const int64_t strideX, const int64_t strideY, const int64_t strideZ, const int64_t paddingX, const int64_t paddingY, const int64_t paddingZ, const int64_t output_paddingX, const int64_t output_paddingY, const int64_t output_paddingZ, const int64_t dilationX, const int64_t dilationY, const int64_t dilationZ, const int64_t paddingMode, const int64_t groups, const bool bias, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ConvTranspose3d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_ConvTranspose3d_weight(const NNModule module);
+EXPORT_API(void) THSNN_ConvTranspose3d_set_weight(const NNModule module, const Tensor weight);
+EXPORT_API(Tensor) THSNN_ConvTranspose3d_bias(const NNModule module);
+EXPORT_API(void) THSNN_ConvTranspose3d_set_bias(const NNModule module, const Tensor bias);
+
// Normalization
EXPORT_API(Tensor) THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps);
@@ -75,6 +213,61 @@ EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, co
EXPORT_API(Tensor) THSNN_grid_sample(const Tensor input, const Tensor grid, const int8_t mode, const int8_t padding_mode, const int8_t align_corners);
EXPORT_API(Tensor) THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size_len, const bool align_corners);
+// Activation functions
+
+EXPORT_API(NNModule) THSNN_CELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_CELU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ELU_ctor(const double alpha, const bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ELU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_GELU_ctor(NNAnyModule* outAsAnyModule, const char* approximate);
+EXPORT_API(Tensor) THSNN_GELU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_GLU_ctor(const int64_t dim, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_GLU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Hardshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Hardshrink_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Hardtanh_ctor(const double min_val, const double max_val, const bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Hardtanh_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_LeakyReLU_ctor(const double negative_slope, const bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_LeakyReLU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Mish_ctor(NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Mish_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_PReLU_ctor(const int64_t nparams, const double init, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_PReLU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(Tensor) THSNN_PReLU_weight(const NNModule module);
+EXPORT_API(void) THSNN_PReLU_set_weight(const NNModule module, const Tensor weight);
+EXPORT_API(NNModule) THSNN_ReLU_ctor(bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReLU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_ReLU6_ctor(bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_ReLU6_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_RReLU_ctor(const double lower, const double upper, const bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_RReLU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_LogSoftmax_ctor(int64_t dim, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_LogSoftmax_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_SELU_ctor(bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_SELU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Sigmoid_ctor(NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Sigmoid_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_SiLU_ctor(NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_SiLU_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Softmax_ctor(const int64_t dim, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Softmax_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Softmax2d_ctor(NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Softmax2d_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Softmin_ctor(const int64_t dim, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Softmin_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Softplus_ctor(const double beta, const double threshold, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Softplus_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Softshrink_ctor(const double lambda, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Softshrink_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Softsign_ctor(NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Softsign_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Tanh_ctor(NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Tanh_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Tanhshrink_ctor(NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Tanhshrink_forward(const NNModule module, const Tensor tensor);
+EXPORT_API(NNModule) THSNN_Threshold_ctor(const double threshold, const double value, const bool inplace, NNAnyModule* outAsAnyModule);
+EXPORT_API(Tensor) THSNN_Threshold_forward(const NNModule module, const Tensor tensor);
+
// Sparse
EXPORT_API(NNModule) THSNN_Embedding_ctor(const int64_t num_embeddings, const int64_t embedding_dims, const int64_t padding_idx, bool has_pi, const double max_norm, const bool has_mn, const double norm_type, const bool scale_grad_by_freq, const bool sparse, NNAnyModule* outAsAnyModule);
@@ -230,6 +423,7 @@ EXPORT_API(Tensor) THSNN_pairwise_distance(const Tensor input1, const Tensor inp
EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual);
+EXPORT_API(Tensor) THSNN_normalize(const Tensor input, float p, const int64_t* dim, float eps, Tensor out);
// Initializers
EXPORT_API(void) THSNN_initUniform(Tensor twrapper, double low, double high);
@@ -246,3 +440,7 @@ EXPORT_API(PackedSequence) THSNN_pack_padded_sequence(Tensor input, Tensor lengt
EXPORT_API(void) THSNN_pad_packed_sequence(PackedSequence sequence, bool batch_first, double padding_value, int64_t total_length, Tensor* res1, Tensor* res2);
EXPORT_API(Tensor) THSNN_pad_sequence(const Tensor* sequences, const int sequences_len, bool batch_first, double padding_value);
EXPORT_API(PackedSequence) THSNN_pack_sequence(const Tensor* sequences, int sequences_len, bool enforce_sorted);
+
+
+// Printer Modules
+EXPORT_API(void) THSNN_Print_Module(const NNModule module);
diff --git a/src/Native/LibTorchSharp/THSStorage.cpp b/src/Native/LibTorchSharp/THSStorage.cpp
index c966e0e97..4bc8b84e9 100644
--- a/src/Native/LibTorchSharp/THSStorage.cpp
+++ b/src/Native/LibTorchSharp/THSStorage.cpp
@@ -23,3 +23,26 @@ void* THSStorage_data_ptr(const Tensor tensor)
return dp.get();
}
+/*
+int* THSStorage_tensor_to_array_int(const Tensor tensor)
+{
+ return THSStorage_tensor_array(tensor);
+}
+long* THSStorage_tensor_to_array_long(const Tensor tensor)
+{
+ return THSStorage_tensor_array(tensor);
+}
+
+float* THSStorage_tensor_to_array_float(const Tensor tensor)
+{
+ return THSStorage_tensor_array(tensor);
+}
+
+double* THSStorage_tensor_to_array_double(const Tensor tensor)
+{
+ return THSStorage_tensor_array(tensor);
+}
+char* THSStorage_tensor_to_array_char(const Tensor tensor)
+{
+ return THSStorage_tensor_array(tensor);
+}*/
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSStorage.h b/src/Native/LibTorchSharp/THSStorage.h
index e66492e11..53a335921 100644
--- a/src/Native/LibTorchSharp/THSStorage.h
+++ b/src/Native/LibTorchSharp/THSStorage.h
@@ -14,3 +14,19 @@ EXPORT_API(size_t) THSStorage_nbytes(const Tensor tensor);
EXPORT_API(void) THSStorage_set_nbytes(const Tensor tensor, size_t nbytes);
EXPORT_API(void*) THSStorage_data_ptr(const Tensor tensor);
+/*
+template
+T* THSStorage_tensor_array(const Tensor tensor)
+{
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 4
+ return tensor->data_ptr();
+#else
+ return tensor->data();
+#endif
+}
+
+EXPORT_API(int*) THSStorage_tensor_to_array_int(const Tensor tensor);
+EXPORT_API(long*) THSStorage_tensor_to_array_long(const Tensor tensor);
+EXPORT_API(float*) THSStorage_tensor_to_array_float(const Tensor tensor);
+EXPORT_API(double*) THSStorage_tensor_to_array_double(const Tensor tensor);
+EXPORT_API(char*) THSStorage_tensor_to_array_char(const Tensor tensor);*/
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp
index 5cff3ab82..9ed15f273 100644
--- a/src/Native/LibTorchSharp/THSTensor.cpp
+++ b/src/Native/LibTorchSharp/THSTensor.cpp
@@ -836,6 +836,21 @@ void THSTensor_index_put_(Tensor tensor,
auto indices = at::ArrayRef(indicesArray, indicesLength);
CATCH(tensor->index_put_(indices, *value););
}
+/*void THSTensor_index_put_accumulate_(Tensor tensor,
+ const int64_t* indexStarts,
+ const int64_t* indexEnds,
+ const int64_t* indexSteps,
+ const Tensor* indexTensors,
+ const int indicesLength,
+ const Tensor value,
+ bool accumulate)
+{
+ at::indexing::TensorIndex* indicesArray = (at::indexing::TensorIndex*)alloca(indicesLength * sizeof(at::indexing::TensorIndex));
+ memset(indicesArray, 0, indicesLength * sizeof(at::indexing::TensorIndex));
+ completeTensorIndices(indexStarts, indexEnds, indexSteps, indexTensors, indicesArray, indicesLength);
+ auto indices = at::ArrayRef(indicesArray, indicesLength);
+ CATCH(tensor->index_put_({ indices }, *value, accumulate););
+}*/
void THSTensor_index_put_scalar_(Tensor tensor,
const int64_t* indexStarts,
@@ -852,6 +867,37 @@ void THSTensor_index_put_scalar_(Tensor tensor,
CATCH(tensor->index_put_(indices, *value););
}
+/*Tensor THSTensor_index_put(Tensor tensor,
+ const int64_t* indexStarts,
+ const int64_t* indexEnds,
+ const int64_t* indexSteps,
+ const Tensor* indexTensors,
+ const int indicesLength,
+ const Tensor value)
+{
+ at::indexing::TensorIndex* indicesArray = (at::indexing::TensorIndex*)alloca(indicesLength * sizeof(at::indexing::TensorIndex));
+ memset(indicesArray, 0, indicesLength * sizeof(at::indexing::TensorIndex));
+ completeTensorIndices(indexStarts, indexEnds, indexSteps, indexTensors, indicesArray, indicesLength);
+ auto indices = at::ArrayRef(indicesArray, indicesLength);
+ CATCH_TENSOR(tensor->index_put(indices, *value););
+}*/
+
+/*Tensor THSTensor_index_put_accumulate(Tensor tensor,
+ const int64_t* indexStarts,
+ const int64_t* indexEnds,
+ const int64_t* indexSteps,
+ const Tensor* indexTensors,
+ const int indicesLength,
+ const Tensor value,
+ bool accumulate)
+{
+ at::indexing::TensorIndex* indicesArray = (at::indexing::TensorIndex*)alloca(indicesLength * sizeof(at::indexing::TensorIndex));
+ memset(indicesArray, 0, indicesLength * sizeof(at::indexing::TensorIndex));
+ completeTensorIndices(indexStarts, indexEnds, indexSteps, indexTensors, indicesArray, indicesLength);
+ auto indices = at::ArrayRef(indicesArray, indicesLength);
+ CATCH_TENSOR(tensor->index_put({ indices }, *value, accumulate););
+}*/
+
Tensor THSTensor_index_select(Tensor tensor, int64_t dim, Tensor index)
{
CATCH_TENSOR(tensor->index_select(dim, *index));
@@ -1237,6 +1283,11 @@ Tensor THSTensor_reshape(const Tensor tensor, const int64_t* shape, const int le
CATCH_TENSOR(tensor->reshape(at::ArrayRef(shape, length)));
}
+void THSTensor_resize_(const Tensor tensor, const int64_t* shape, const int length)
+{
+ CATCH(tensor->resize_(at::ArrayRef(shape, length)););
+}
+
Tensor THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2)
{
CATCH_TENSOR(tensor->rot90(k, { dim1, dim2 }));
@@ -1867,6 +1918,21 @@ Tensor THSTensor_to_type_and_device(const Tensor tensor, int8_t scalar_type, con
);
}
+/*Tensor THSTensor_device_and_non_blocking(const Tensor tensor, const int device_type, const int device_index, const bool non_blocking)
+{
+ CATCH_RETURN_Tensor(
+ auto device = c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index);
+ res = ResultTensor(tensor->to(device, non_blocking, at::ScalarType(scalar_type), false));
+ );
+}*/
+Tensor THSTensor_to_type_and_device_and_non_blocking(const Tensor tensor, int8_t scalar_type, const int device_type, const int device_index,const bool non_blocking)
+{
+ CATCH_RETURN_Tensor(
+ auto device = c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index);
+ res = ResultTensor(tensor->to(device, at::ScalarType(scalar_type),non_blocking, false));
+ );
+}
+
Tensor THSTensor_triu(const Tensor tensor, const int64_t diagonal, const bool inplace)
{
CATCH_TENSOR(inplace ? tensor->triu_(diagonal) : tensor->triu(diagonal));
@@ -2253,3 +2319,16 @@ Tensor THSTensor_unflatten_names(Tensor tensor, const char** names, const int64_
return nullptr;
}
+
+bool THSTensor_is_coalesce(Tensor tensor)
+{
+ return tensor->is_coalesced();
+}
+
+Tensor THSTensor_coalesce(Tensor tensor)
+{
+ CATCH(
+ return ResultTensor(tensor->coalesce());
+ );
+ return nullptr;
+}
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h
index ebbdf8302..2a013e4bc 100644
--- a/src/Native/LibTorchSharp/THSTensor.h
+++ b/src/Native/LibTorchSharp/THSTensor.h
@@ -660,6 +660,7 @@ EXPORT_API(void) THSTensor_index_copy_(const Tensor tensor, const int64_t dim, c
EXPORT_API(Tensor) THSTensor_index_fill(const Tensor tensor, const int64_t dim, const Tensor index, const Scalar value);
EXPORT_API(void) THSTensor_index_fill_(const Tensor tensor, const int64_t dim, const Tensor index, const Scalar value);
+
EXPORT_API(Tensor) THSTensor_indices(Tensor tensor);
EXPORT_API(Tensor) THSTensor_index(Tensor tensor,
@@ -669,6 +670,14 @@ EXPORT_API(Tensor) THSTensor_index(Tensor tensor,
const Tensor* indexTensors,
const int indicesLength);
+EXPORT_API(void) THSTensor_index_put_(Tensor tensor,
+ const int64_t* indexStarts,
+ const int64_t* indexEnds,
+ const int64_t* indexSteps,
+ const Tensor* indexTensors,
+ const int indicesLength,
+ const Tensor value);
+
EXPORT_API(void) THSTensor_index_put_scalar_(Tensor tensor,
const int64_t* indexStarts,
const int64_t* indexEnds,
@@ -677,13 +686,31 @@ EXPORT_API(void) THSTensor_index_put_scalar_(Tensor tensor,
const int indicesLength,
const Scalar value);
-EXPORT_API(void) THSTensor_index_put_(Tensor tensor,
+/*EXPORT_API(void) THSTensor_index_put_accumulate_(Tensor tensor,
+ const int64_t* indexStarts,
+ const int64_t* indexEnds,
+ const int64_t* indexSteps,
+ const Tensor* indexTensors,
+ const int indicesLength,
+ const Tensor value,
+ bool accumulate);*/
+
+/*EXPORT_API(Tensor) THSTensor_index_put(Tensor tensor,
const int64_t* indexStarts,
const int64_t* indexEnds,
const int64_t* indexSteps,
const Tensor* indexTensors,
const int indicesLength,
const Tensor value);
+*/
+/*EXPORT_API(Tensor) THSTensor_index_put_accumulate(Tensor tensor,
+ const int64_t* indexStarts,
+ const int64_t* indexEnds,
+ const int64_t* indexSteps,
+ const Tensor* indexTensors,
+ const int indicesLength,
+ const Tensor value,
+ bool accumulate);*/
EXPORT_API(Tensor) THSTensor_index_select(Tensor tensor, int64_t dim, Tensor index);
@@ -1150,6 +1177,8 @@ EXPORT_API(int) THSTensor_requires_grad(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_reshape(const Tensor tensor, const int64_t* shape, const int length);
+EXPORT_API(void) THSTensor_resize_(const Tensor tensor, const int64_t* shape, const int length);
+
EXPORT_API(Tensor) THSTensor_roll(const Tensor tensor, const int64_t* shifts, const int shLength, const int64_t* dims, const int dimLength);
EXPORT_API(Tensor) THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2);
@@ -1388,6 +1417,10 @@ EXPORT_API(Tensor) THSTensor_to_type(const Tensor tensor, int8_t scalar_type, co
EXPORT_API(Tensor) THSTensor_to_type_and_device(const Tensor tensor, int8_t scalar_type, const int device_type, const int device_index, const bool copy, const bool non_blocking);
+//EXPORT_API(Tensor) THSTensor_device_and_non_blocking(const Tensor tensor, const int device_type, const int device_index, const bool non_blocking);
+
+EXPORT_API(Tensor) THSTensor_to_type_and_device_and_non_blocking(const Tensor tensor, int8_t scalar_type, const int device_type, const int device_index, const bool non_blocking);
+
EXPORT_API(void) THSTensor_topk(const Tensor tensor, Tensor* (*allocator)(size_t length), const int k, const int64_t dim, const bool largest, const bool sorted);
EXPORT_API(Tensor) THSTensor_trunc(const Tensor tensor);
@@ -1783,7 +1816,6 @@ EXPORT_API(Tensor) THSTensor_fftshift(const Tensor tensor, const int64_t* dim, c
EXPORT_API(Tensor) THSTensor_ifftshift(const Tensor tensor, const int64_t* dim, const int dim_length);
-
// Spectral Ops
EXPORT_API(Tensor) THSTensor_bartlett_window(const int64_t len, bool periodic, const int8_t scalar_type, const int device_type, const int device_index, const bool requires_grad);
@@ -1794,3 +1826,6 @@ EXPORT_API(Tensor) THSTensor_kaiser_window(const int64_t len, bool periodic, dou
EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, int64_t onesided, bool return_complex);
EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, int64_t onesided, int64_t length, bool return_complex);
+
+EXPORT_API(Tensor) THSTensor_coalesce(const Tensor x);
+EXPORT_API(bool) THSTensor_is_coalesce(const Tensor x);
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSTorch.cpp b/src/Native/LibTorchSharp/THSTorch.cpp
index b846557bc..995a2cd37 100644
--- a/src/Native/LibTorchSharp/THSTorch.cpp
+++ b/src/Native/LibTorchSharp/THSTorch.cpp
@@ -323,4 +323,10 @@ double THSSpecial_erf_scalar(const double x)
double THSSpecial_erfc_scalar(const double x)
{
return erfc(x);
-}
\ No newline at end of file
+}
+
+
+/*bool THSTorch_jit_is_scripting()
+{
+
+}*/
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSTorch.h b/src/Native/LibTorchSharp/THSTorch.h
index 9ab80e828..6b515f64a 100644
--- a/src/Native/LibTorchSharp/THSTorch.h
+++ b/src/Native/LibTorchSharp/THSTorch.h
@@ -4,7 +4,8 @@
#include "../Stdafx.h"
#include "Utils.h"
-
+#include
+//#include
// API.
// Sets manually the seed.
@@ -91,3 +92,4 @@ EXPORT_API(void) THSTorch_dispose_scalar(Scalar scalar);
EXPORT_API(double) THSSpecial_erf_scalar(const double x);
EXPORT_API(double) THSSpecial_erfc_scalar(const double x);
+
diff --git a/src/Native/LibTorchSharp/Utils.h b/src/Native/LibTorchSharp/Utils.h
index 4c3606491..42573753b 100644
--- a/src/Native/LibTorchSharp/Utils.h
+++ b/src/Native/LibTorchSharp/Utils.h
@@ -2,9 +2,8 @@
#pragma once
#include
-
#include "torch/torch.h"
-
+#include
extern thread_local char *torch_last_err;
typedef torch::Tensor *Tensor;
@@ -59,8 +58,24 @@ struct TensorArray {
// Return undefined tensors as nullptr to C#
inline Tensor ResultTensor(const at::Tensor & res)
{
- if (res.defined())
+ if (res.defined()) {
+
+ //TODO: autocast here only when inside an autocast (inner) scope
+
+ /*at::Tensor* resT = new torch::Tensor(res);
+ if (at::autocast::is_autocast_cache_enabled()){
+ if (res.is_cuda()) {
+ ::std::cout << "IS CUDA" << std::endl;
+ resT->to(at::autocast::get_autocast_gpu_dtype());
+ }
+ if (res.is_cpu()) {
+ ::std::cout << "IS CPU" << std::endl;
+ resT->to(at::autocast::get_autocast_cpu_dtype());
+ }
+ }
+ return resT;*/
return new torch::Tensor(res);
+ }
else
return nullptr;
}
diff --git a/src/Native/build.cmd b/src/Native/build.cmd
index c805b2608..96ec8cacf 100644
--- a/src/Native/build.cmd
+++ b/src/Native/build.cmd
@@ -148,4 +148,4 @@ exit /B 0
:Failure
:: Build failed
echo Failed to generate native component build project!
-exit /b 1
+exit /b 1
\ No newline at end of file
diff --git a/src/Native/build.proj b/src/Native/build.proj
index 6dbbc70a9..a6898465d 100644
--- a/src/Native/build.proj
+++ b/src/Native/build.proj
@@ -31,7 +31,6 @@
Condition="'$(OS)' != 'Windows_NT'">
-
--stripsymbols
--configuration $(NativeConfiguration) --arch $(TargetArchitecture) $(StripArgs) --libtorchpath $(LibTorchCmakePath)
@@ -44,9 +43,13 @@
-
+
$(NativeConfiguration) $(TargetArchitecture) --libtorchpath $(LibTorchCmakePath)
+
+
+ $(NativeConfiguration) $(TargetArchitecture) --libtorchpath $(CustomLibTorchFullPath)
+
@@ -57,8 +60,7 @@
-
+
diff --git a/src/TorchSharp/Amp/AMPManager.cs b/src/TorchSharp/Amp/AMPManager.cs
new file mode 100644
index 000000000..11bc1aaa2
--- /dev/null
+++ b/src/TorchSharp/Amp/AMPManager.cs
@@ -0,0 +1,215 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using TorchSharp.PInvoke;
+
+namespace TorchSharp.Amp
+{
+ [Obsolete("Use AutocastMode instead", true)]
+ public class AMPManager : IDisposable
+ {
+
+ //TODO: Make Singleton THREADSAFE
+ public class TensorConverter
+ {
+ //public torch.Tensor Tensor;
+ public IntPtr PrevHandle;
+ public IntPtr Handle;
+ public torch.ScalarType Dtype;
+ public torch.ScalarType FastDtype = torch.ScalarType.Float32;
+ public TensorCalledIn Called, Status;
+ public enum TensorCalledIn
+ {
+ OutSide,
+ InsideEnter
+ }
+
+ public TensorConverter(IntPtr handle)
+ {
+ this.PrevHandle = handle;
+ this.Handle = handle;
+ this.Dtype = (torch.ScalarType)NativeMethods.THSTensor_type(handle);
+ this.FastDtype = AutocastMode.GetInstance().GetFastType();
+
+ Status = TensorConverter.TensorCalledIn.InsideEnter;
+ }
+ /*public TensorConverter(torch.Tensor tensor) : this(tensor.handle)
+ {
+ this.Tensor = tensor;
+ }*/
+ }
+
+ public IList TensorsCasts = new List();
+ public bool IsEnter = false;
+ public bool IsDisposed = false;
+ /*public UnorderedMap TensorPtrs= new UnorderedMap();
+ public UnorderedMap TensorMap= new UnorderedMap();*/
+ private AutocastMode autocastMode=null;
+ public bool IsEnabled {
+ get {
+ if (autocastMode == null)
+ return false;
+ return autocastMode.IsEnabled;
+ }
+ }
+
+ private AMPManager(bool enabled)
+ {
+ if (!torch.cuda_is_available())
+ return;
+ autocastMode = AutocastMode.GetInstance(enabled);
+ }
+
+ private static AMPManager Instance;
+ public static AMPManager GetInstance(bool enabled = false)
+ {
+ return Instance ??= new AMPManager(enabled);
+ }
+
+ private torch.ScalarType GetType(IntPtr handle)
+ {
+ return (torch.ScalarType)NativeMethods.THSTensor_type(handle);
+ }
+
+ public IntPtr AutoCast(IntPtr handle)
+ {
+ return ToIf(handle, AutocastMode.GetInstance().GetFastType());
+ }
+
+ public torch.Tensor AutoCast(torch.Tensor tensor)
+ {
+ return new torch.Tensor(AutoCast(tensor.Handle));
+ //return tensor.to(AutocastMode.GetInstance().GetFastType());
+ }
+ public static IntPtr To(IntPtr ptr, torch.ScalarType type)
+ {
+ Debug.WriteLine($"{nameof(AMPManager)} Tensor converting from: {(torch.ScalarType)NativeMethods.THSTensor_type(ptr)} to: {type}");
+ var res = NativeMethods.THSTensor_to_type(ptr, (sbyte)type);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return res;
+ }
+ public static IntPtr ToIf(IntPtr ptr, torch.ScalarType type)
+ {
+ if (!AMPManager.GetInstance().IsEnabled)
+ return ptr;
+ var res = NativeMethods.THSTensor_to_type(ptr, (sbyte)type);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return res;
+ }
+ private void Revert()
+ {
+ for (int i = 0; i < TensorsCasts.Count; i++) {
+ var tc = TensorsCasts[i];
+ //var tt = new torch.Tensor(tc.Handle);
+ //var t = new torch.Tensor(tc.Handle) { handle = To(tc.Handle, tc.Dtype) };
+ //var t = new torch.Tensor(tc.Handle).to(tc.Dtype);
+ tc.Handle= To(tc.Handle, tc.Dtype);
+ if (tc.Handle != tc.PrevHandle)
+ tc.PrevHandle = To(tc.PrevHandle, tc.Dtype);
+ }
+ // Casting works well, but un-casting does not work when the tensor goes out of scope (reason unknown).
+ //TensorsCasts.Clear();
+ }
+
+
+ private int ExistsHandle(IntPtr handle)
+ {
+ for (int i = 0; i < TensorsCasts.Count; i++)
+ if (TensorsCasts[i].PrevHandle == handle || TensorsCasts[i].Handle == handle)
+ return i;
+ return -1;
+ }
+
+ public IntPtr Work(IntPtr handle, IntPtr prev)
+ {
+ if (!this.IsEnabled)
+ return handle;
+ /*if (IsDisposed && !IsEnter) {
+ Revert(); //Is for cleaned all
+ return IntPtr.Zero;
+ }*/
+ var idx = ExistsHandle(handle);
+ Console.WriteLine($"PTR: {handle}, PREV: {prev}, IDX: {idx}, {GetType(handle)}");
+ if (idx == -1) {
+ var tc = new TensorConverter(handle) { Called = IsEnter
+ ? TensorConverter.TensorCalledIn.InsideEnter
+ : TensorConverter.TensorCalledIn.OutSide
+ };
+
+ if (IsEnter)
+ tc.Handle = To(tc.Handle, tc.FastDtype);
+ TensorsCasts.Add(tc);
+ return tc.Handle;
+ }
+ var tcidx = TensorsCasts[idx];
+ tcidx.Handle = handle;
+ return tcidx.Handle;
+ /*if (!IsEnter && IsDisposed) {
+ if (tcidx.Called == TensorConverter.TensorCalledIn.OutSide) { //Is created outside so this can revert
+ //Is From Outside and is disposed, the tensor is created Outside so i will revert this
+ tcidx.PrevHandle = tcidx.Handle;
+ tcidx.Handle = To(tcidx.Handle, tcidx.Dtype);
+ }
+ return tcidx.Handle;
+ }
+ if (GetType(tcidx.Handle) == tcidx.FastDtype)
+ return tcidx.Handle;
+
+ if (IsEnter) {
+ tcidx.PrevHandle = tcidx.Handle;
+ tcidx.Handle = To(tcidx.Handle, tcidx.FastDtype);
+ }
+ return tcidx.Handle;*/
+ }
+
+ public IDisposable Enter()
+ {
+ if (!torch.cuda_is_available())
+ return this;
+ IsEnter = true;
+ IsDisposed = false;
+ autocastMode.SetEnabled(true, torch.CUDA);
+ Debug.WriteLine($"{nameof(AMPManager)} Enter call");
+ return this;
+ }
+ protected virtual void Dispose(bool disposing)
+ {
+ Debug.WriteLine($"{nameof(AMPManager)} Disposed call");
+ IsDisposed = true;
+ IsEnter = false;
+ Revert();
+ //Work(IntPtr.Zero, IntPtr.Zero);
+ autocastMode.Dispose();
+ //Revert();
+ /*TensorPtrs.Dispose();
+ TensorMap.Dispose();*/
+ /*if (!disposedValue) {
+ if (disposing) {
+
+
+ // TODO: dispose managed state (managed objects)
+ }
+
+ // TODO: free unmanaged resources (unmanaged objects) and override finalizer
+ // TODO: set large fields to null
+ disposedValue = true;
+ }*/
+ }
+
+ // // TODO: override finalizer only if 'Dispose(bool disposing)' has code to free unmanaged resources
+ /*~AMPManager()
+ {
+ Dispose(false);
+ }*/
+
+ public void Dispose()
+ {
+ // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
+ Dispose(disposing: true);
+ GC.SuppressFinalize(this);
+ }
+ }
+}
diff --git a/src/TorchSharp/Amp/AutocastMode.cs b/src/TorchSharp/Amp/AutocastMode.cs
new file mode 100644
index 000000000..ef0c8a43c
--- /dev/null
+++ b/src/TorchSharp/Amp/AutocastMode.cs
@@ -0,0 +1,222 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Security.Cryptography;
+using System.Text;
+using System.Threading.Tasks;
+using TorchSharp.PInvoke;
+using TorchSharp.Utils;
+
+namespace TorchSharp.Amp
+{
+ /*public static class Autocast
+ {
+ public static torch.Tensor AutoCast(this torch.Tensor input)
+ {
+ return AutocastMode.GetInstance().CastTensor(input);
+ }
+ }*/
+ //TODO: Should make Singleton and IDisposable on ENTER
+ public sealed class AutocastMode : IDisposable
+ {
+ public bool _enabled=false;
+ public bool IsEnter { private set; get; }=false;
+ public bool IsDisposed = false;
+ private bool prev_cache_enabled, prev;
+ private torch.ScalarType prev_fastdtype;
+ //internal bool Prev;
+ private bool _cache_enabled=false;
+ internal torch.ScalarType fast_dtype = torch.ScalarType.Float32;
+ internal torch.ScalarType? dtype = torch.ScalarType.Float32;
+ public DeviceType device = DeviceType.CUDA;
+ private static AutocastMode instance;
+ public static AutocastMode GetInstance(bool enabled=false)
+ {
+ //https://github.com/pytorch/pytorch/blob/e6ff07f00e04a9b58efb86a3dd70ed7280ae8522/torch/fx/experimental/proxy_tensor.py#L1251
+ return instance ??= new AutocastMode(torch.cuda_is_available() ? torch.CUDA : torch.CPU, enabled:enabled,cache_enabled:true);
+ }
+
+ private AutocastMode(torch.Device dev, torch.ScalarType? dtype = null, bool enabled=true, bool? cache_enabled = null)
+ {
+ //https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float16
+ if (dtype == null)
+ dtype = torch.get_autocast_dtype(dev.type);
+ this.device = dev.type;
+ if (!torch.is_autocast_available(device))
+ throw new Exception($"User specified an unsupported autocast device_type {device}");
+ fast_dtype = torch.get_autocast_dtype(device); //If device is CPU this may return as BFloat16
+ _cache_enabled = torch.is_autocast_cache_enabled();
+ if (enabled && !torch.cuda_is_available() && dev.type == DeviceType.CUDA) //Is not available for doing multicast
+ enabled = false;
+ if (this.dtype.HasValue)
+ fast_dtype = dtype.Value;
+ if (cache_enabled.HasValue)
+ _cache_enabled = cache_enabled.Value;
+ if (dev.type != DeviceType.CPU && dev.type != DeviceType.CUDA && enabled)
+ throw new Exception($"Currently autocast does not support {dev.type} only CPU or CUDA");
+ /*if (dev.type == DeviceType.CPU) {
+ if (torch.get_autocast_dtype(device) != torch.ScalarType.Float32) {
+ Debug.WriteLine($"Currently {torch.get_autocast_dtype(device)} is not supported on CPU; that feature will be added.");
+ }
+ fast_dtype = torch.ScalarType.Float32;
+ }*/
+ if (dev.type == DeviceType.CPU) {
+ //https://github.com/pytorch/pytorch/blob/e6ff07f00e04a9b58efb86a3dd70ed7280ae8522/torch/amp/autocast_mode.py#L277
+ if (enabled && (fast_dtype != torch.ScalarType.Float16 || fast_dtype != torch.ScalarType.BFloat16)) {
+ Debug.WriteLine($"In CPU autocast, but the target dtype is not supported. Disabling autocast. CPU autocast only supports dtype of {torch.ScalarType.Float16} or {torch.ScalarType.BFloat16}");
+ enabled = false;
+ }
+ } else if (dev.type == DeviceType.CUDA) {
+ if (enabled && fast_dtype == torch.ScalarType.BFloat16 && !torch.cuda.is_bf16_supported())
+ throw new Exception("Current CUDA Device does not support bfloat16. Please switch dtype to float16.");
+ }
+
+ torch.set_autocast_enabled(dev.type, true);
+ this._enabled = enabled;
+ }
+
+ public torch.ScalarType GetFastType()
+ {
+ return torch.get_autocast_dtype(device);
+ }
+ private static torch.ScalarType GetDtype(IntPtr handle)
+ {
+ return (torch.ScalarType)NativeMethods.THSTensor_type(handle);
+ }
+
+ public static IntPtr AutoCast(IntPtr handle)
+ {
+ return ToIf(handle, GetInstance().GetFastType());
+ }
+ public static (IntPtr h1, IntPtr h2) AutoCast(IntPtr handle1, IntPtr handle2)
+ {
+ var ft = GetInstance().GetFastType();
+ return (ToIf(handle1, ft), ToIf(handle2, ft));
+ }
+ public static (IntPtr h1, IntPtr h2, IntPtr h3) AutoCast(IntPtr handle1, IntPtr handle2, IntPtr handle3)
+ {
+ var ft = GetInstance().GetFastType();
+ return (ToIf(handle1, ft), ToIf(handle2, ft), ToIf(handle3, ft));
+ }
+ public static (IntPtr h1, IntPtr h2) AutoCast(IntPtr handle1, IntPtr handle2, torch.ScalarType dtype)
+ {
+ return (ToIf(handle1, dtype), ToIf(handle2, dtype));
+ }
+
+ public static (IntPtr h1, IntPtr h2, IntPtr h3) AutoCast(IntPtr handle1, IntPtr handle2, IntPtr handle3, torch.ScalarType dtype)
+ {
+ return (ToIf(handle1, dtype), ToIf(handle2, dtype), ToIf(handle3, dtype));
+ }
+
+ public static IntPtr AutoCast(IntPtr handle, torch.ScalarType dtype)
+ {
+ return ToIf(handle, dtype);
+ }
+
+ public static torch.Tensor AutoCast(torch.Tensor tensor)
+ {
+ return new torch.Tensor(AutoCast(tensor.Handle));
+ //return tensor.to(AutocastMode.GetInstance().GetFastType());
+ }
+ public static IntPtr To(IntPtr ptr, torch.ScalarType type)
+ {
+ Debug.WriteLine($"{nameof(AutocastMode)} Tensor converting from: {GetDtype(ptr)} to: {type}");
+ var res = NativeMethods.THSTensor_to_type(ptr, (sbyte)type);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return res;
+ }
+
+ private static DeviceType GetDeviceType(IntPtr ptr)
+ {
+ return (DeviceType)NativeMethods.THSTensor_device_type(ptr);
+ }
+ public static IntPtr ToIf(IntPtr ptr, torch.ScalarType type)
+ {
+ if(GetInstance().device != DeviceType.CPU) //Warning: Remove this if is finished and working the struct BFloat16 C10
+ if (!IsAutocastEnabled() || !GetInstance().IsEnter)
+ return ptr;
+ if (GetDtype(ptr) == type) // if the tensor already has the target dtype, no conversion is necessary
+ return ptr;
+
+ //TODO: Check if is from CPU to passing BFloat16 if support
+ /*if (!NativeMethods.THSAmp_is_autocast_enabled(NativeMethods.THSTensor_device_type(ptr)))
+ return ptr;*/
+ var res = NativeMethods.THSTensor_to_type(ptr, (sbyte)type);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return res;
+ }
+ public static IntPtr ToIf(IntPtr ptr, torch.ScalarType type, DeviceType device_type)
+ {
+ bool is_elegible = GetDtype(ptr) != torch.ScalarType.Float64 && GetDeviceType(ptr) == device_type;
+
+ if (!NativeMethods.THSAmp_is_autocast_enabled(NativeMethods.THSTensor_device_type(ptr)))
+ return ptr;
+ var res = NativeMethods.THSTensor_to_type(ptr, (sbyte)type);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return res;
+ }
+
+ public static bool IsAutocastEnabled(DeviceType device = DeviceType.CUDA)
+ {
+ return torch.is_autocast_enabled(!torch.cuda_is_available() ? DeviceType.CPU : device);
+ }
+
+ public IDisposable Enter()
+ {
+ prev_cache_enabled = torch.is_autocast_cache_enabled();
+ prev = torch.is_autocast_enabled(device);
+ prev_fastdtype = torch.get_autocast_dtype(device);
+ torch.set_autocast_enabled(device, _enabled);
+ torch.set_autocast_dtype(device, fast_dtype);
+ torch.autocast_increment_nesting();
+ torch.set_autocast_cache_enabled(_cache_enabled);
+ IsEnter = true;
+ /*if (!_enabled) //Research this — may be a bad idea?
+ return new AutocastMode(new torch.Device(DeviceType.CUDA));*/
+ return this;
+ }
+
+ public static IDisposable AutoCastEnter()
+ {
+ return AutocastMode.GetInstance().Enter();
+ }
+
+ public void Disabled()
+ {
+ _enabled = false;
+ Dispose();
+ }
+ private void Dispose(bool disposing)
+ {
+ IsEnter = false;
+ if (torch.autocast_decrement_nesting() == 0)
+ torch.clear_autocast_cache();
+ torch.set_autocast_enabled(device, prev);
+ torch.set_autocast_dtype(device, prev_fastdtype);
+ torch.set_autocast_cache_enabled(prev_cache_enabled);
+ }
+
+ public void Dispose()
+ {
+ Dispose(disposing: true);
+ GC.SuppressFinalize(this);
+ }
+ }
+ ///
+ /// Attempt at a custom autocast attribute for forward methods, equivalent to PyTorch's
+ /// @torch.autocast(device_type="cuda") decorator.
+ ///
+ public class AutocastAttribute : Attribute
+ {
+ private DeviceType Dev;
+ public AutocastAttribute(DeviceType dev)
+ {
+ Dev = dev;
+ }
+ }
+}
diff --git a/src/TorchSharp/Amp/GradScaler.cs b/src/TorchSharp/Amp/GradScaler.cs
new file mode 100644
index 000000000..4aef1a249
--- /dev/null
+++ b/src/TorchSharp/Amp/GradScaler.cs
@@ -0,0 +1,493 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using TorchSharp.Modules;
+using TorchSharp.Utils;
+
+namespace TorchSharp.Amp
+{
+ public class GradScaler : IDisposable
+ {
+ private bool Enabled;
+ public torch.Device device;
+ private torch.Tensor _scale, _growth_tracker;
+ private float InitScale, InitGrowthTracker;
+ public float _growth_factor { set; get; }
+ public float _backoff_factor { set; get; }
+ private int _growth_interval { set; get; }
+ private UnorderedMap> _per_optimizer_states = new UnorderedMap>();
+ bool disposedValue;
+
+ public enum OptState
+ {
+ Ready,
+ Unscaled,
+ Stepped
+ }
+
+ private UnorderedMap _refresh_per_optimizer_state()
+ {
+ return new UnorderedMap() {
+ { "stage", OptState.Ready }, { "found_inf_per_device", null}
+ };
+ }
+ //https://github.com/pytorch/pytorch/blob/main/torch/amp/grad_scaler.py
+ public GradScaler(torch.Device dev, float init_scale = 2.0e16f, float growth_factor = 2.0f,
+ float backoff_factor = 0.5f, int growth_interval = 2000, bool enabled = true)
+ {
+ //https://gist.github.com/dorpxam/67ad2bc222b2cf567d4a6fc298375e13
+ Debug.Assert(dev.type == DeviceType.CPU || dev.type== DeviceType.CUDA);
+ device = dev;
+ Enabled = enabled;
+ InitScale = init_scale;
+ if (Enabled) {
+ Debug.Assert(growth_factor > 1.0);
+ Debug.Assert(backoff_factor < 1.0);
+ }
+ this._growth_factor = growth_factor;
+ _backoff_factor = backoff_factor;
+ _growth_interval = growth_interval;
+ InitGrowthTracker = 0.0f;
+
+ _per_optimizer_states.SetDefaultDict(_refresh_per_optimizer_state());
+ //throw new NotImplementedException("This needs to be finished");
+ }
+
+ private Tuple check_scale_growth_tracker(string name)
+ {
+ var fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration.";
+ Debug.Assert(!(_scale is null), $"Attempted {name} but {nameof(_scale)} is None {fix}");
+ Debug.Assert(!(_growth_tracker is null), $"Attempted {name} but {nameof(_growth_tracker)} is None {fix}");
+ return new Tuple(_scale, _growth_tracker);
+ }
+
+
+ private void LazyInitScaleGrowthTracker(torch.Device dev)
+ {
+ Debug.Assert(_growth_tracker is null, "_growth_tracker initialized before _scale");
+
+ _scale = torch.full(1, InitScale, torch.ScalarType.Float32, device: dev);
+ _growth_tracker = torch.full(1, InitGrowthTracker, torch.ScalarType.Int32, device: dev);
+ }
+ //private Dictionary
+
+ //private check_scale_growth_tracker
+ public torch.Tensor scale(torch.Tensor output)
+ {
+ if (!Enabled)
+ return output;
+ if (_scale is null)
+ LazyInitScaleGrowthTracker(output.device);
+ Debug.Assert(!(_scale is null));
+ return output * _scale.to(output.device, output.dtype, true);
+ }
+
+ public IList scale(IList outputs)
+ {
+ apply_scale(outputs);
+ return outputs;
+ }
+ private class MultiDeviceReplicator
+ {
+ private readonly torch.Tensor master;
+
+ internal readonly Dictionary per_device_tensors = new Dictionary();
+ public MultiDeviceReplicator(torch.Tensor master_tensor)
+ {
+ master = master_tensor;
+ }
+
+ public torch.Tensor Get(DeviceType device)
+ {
+ torch.Tensor retval=null;
+ if (!per_device_tensors.ContainsKey(device)) {
+ retval = master.to(new torch.Device(device), true, non_blocking: true);
+ per_device_tensors.Add(device, retval);
+ }
+ return retval;
+ }
+ }
+
+ private torch.Tensor apply_scale(torch.Tensor scale)
+ {
+ IList stash = new List();
+ if (stash.Count == 0) {
+ if (_scale is null) {
+ LazyInitScaleGrowthTracker(scale.device);
+ }
+ stash.Add(new MultiDeviceReplicator(_scale));
+ }
+ return scale * stash[0].Get(scale.device.type);
+ }
+
+ private void apply_scale(IList scales)
+ {
+ for (int i = 0; i < scales.Count; i++)
+ scales[i] = apply_scale(scales[i]);
+ }
+ public Dictionary unscale_grads(torch.optim.Optimizer optimizer, torch.Tensor inv_scale, torch.Tensor found_inf, bool allow_fp16)
+ {
+ var per_device_inv_scale = new MultiDeviceReplicator(inv_scale);
+ var per_device_found_inf= new MultiDeviceReplicator(found_inf);
+ Dictionary>> per_device_and_dtype_grads = new Dictionary>>();
+
+ using (torch.no_grad()) {
+
+ using (var enumer = optimizer.parameters().GetEnumerator()) {
+ while (enumer.MoveNext()) {
+ var param = enumer.Current;
+ if (param is null)
+ continue;
+ if (!allow_fp16 && param.dtype == torch.ScalarType.Float16)
+ throw new Exception("Attempting to unscale FP16 Gradients");
+ torch.Tensor to_unscale;
+ if (param.grad.is_sparse) {
+ if (param.grad.dtype == torch.ScalarType.Float16) {
+ param.grad = param.grad.coalesce();
+ }
+
+ to_unscale = param.grad.SparseValues;
+ } else {
+ to_unscale = param.grad;
+ }
+
+ if (!per_device_and_dtype_grads.ContainsKey(to_unscale.device.type)) {
+ per_device_and_dtype_grads.Add(to_unscale.device.type, new Dictionary>());
+ per_device_and_dtype_grads[to_unscale.device.type].Add(to_unscale.dtype, new List());
+ per_device_and_dtype_grads[to_unscale.device.type][to_unscale.dtype].Add(to_unscale);
+ } else {
+ if (!per_device_and_dtype_grads[to_unscale.device.type].ContainsKey(to_unscale.dtype)) {
+ per_device_and_dtype_grads[to_unscale.device.type].Add(to_unscale.dtype, new List());
+ } else {
+ per_device_and_dtype_grads[to_unscale.device.type][to_unscale.dtype].Add(to_unscale);
+ }
+ }
+
+ }
+ }
+
+ foreach (var d in per_device_and_dtype_grads)
+ foreach (var g in d.Value)
+ torch._amp_foreach_non_finite_check_and_unscale_(g.Value, per_device_found_inf.Get(d.Key), per_device_inv_scale.Get(d.Key));
+
+ }
+
+ return per_device_found_inf.per_device_tensors;
+ }
+
+ public void unscale(torch.optim.Optimizer optimizer)
+ {
+ if (!Enabled)
+ return;
+
+ check_scale_growth_tracker(nameof(unscale));
+ //if(_per_optimizer_states.ContainsKey(optimizer.GetHashCode()))
+
+ var optimizer_state = _per_optimizer_states[optimizer.GetHashCode()];
+ if (optimizer_state["stage"] is OptState state) {
+ if (state == OptState.Unscaled) {
+ throw new Exception($"{nameof(unscale)} has already been called on this optimizer since the last update()");
+ }
+ else if(state == OptState.Stepped)
+ throw new Exception($"{nameof(unscale)} is being called after step()");
+ }
+
+ Debug.Assert(!(_scale is null));
+ var inv_scale = _scale.to(torch.ScalarType.Float64).reciprocal().to(torch.ScalarType.Float32);
+ var found_inf = torch.full(1, 0.0f, torch.ScalarType.Float32,_scale.device);
+
+ optimizer_state["found_inf_per_device"] = unscale_grads(optimizer, inv_scale, found_inf, false);
+
+ optimizer_state["stage"] = OptState.Unscaled;
+ }
+ /*
+ *
+
+ template
+ inline auto sum(PerDeviceTensors const& per_device)
+ {
+ Type sum = Type(0);
+ for (auto&& [_, v] : per_device)
+ sum += v.item();
+ return sum;
+ }
+ *
+ */
+ private Scalar maybe_opt_step(torch.optim.Optimizer optimizer, UnorderedMap optimizer_state, Func closure = null)
+ {
+ //https://github.com/pytorch/pytorch/blob/a00fad017719346bac6e08da0819358146e647e3/torch/amp/grad_scaler.py#L351
+ if (optimizer_state.ContainsKey("found_inf_per_device")) {
+
+ double? retval = 0;
+ if (optimizer_state["found_inf_per_device"] is Dictionary dict) {
+ foreach (var d in dict)
+ {
+ retval += (double)d.Value.item();
+ //retval += d.Value.Sum(x=>x.item());
+ /*foreach(var t in d.Value)
+ retval += t.item();*/
+ //retval += d.Value.item();
+ }
+ /*if (retval.HasValue) {
+ if(retval.Value > 0)
+ return
+ }*/
+
+ //https://gist.github.com/dorpxam/67ad2bc222b2cf567d4a6fc298375e13#file-gradscaler-hpp-L209
+ }
+ /*foreach (var d in optimizer_state)
+ if (d.Value is torch.Tensor t)
+ retval += t.item();*/
+ var res = optimizer.step(closure);
+ if (!(res is null)) {
+ return res.item();
+ }
+
+ /*if (retval == 0)
+ retval = .item();
+ return retval;*/
+ }
+
+ return null;
+ }
+
+ public Scalar step(torch.optim.Optimizer optimizer, Func optimizer_args = null)
+ {
+ if (!Enabled) {
+ var res = optimizer.step(optimizer_args);
+ if (!(res is null))
+ return res.item();
+ return null;
+ }
+
+ if (optimizer_args != null)
+ throw new Exception("Closure use is not currently supported if GradScaler is Enabled");
+
+ /*if (!Enabled) {
+ if(obj.Length == 1 && obj[0] is Func closure)
+ return optimizer.step(closure).item();
+ return null;
+ }*/
+
+ check_scale_growth_tracker(nameof(step));
+ var optimizer_state = _per_optimizer_states[optimizer.GetHashCode()];
+
+ if (optimizer_state["stage"] is OptState state && state == OptState.Stepped)
+ throw new Exception($"{nameof(step)} has already been called since the last update()");
+ Scalar retval=null;
+
+ //https://github.com/pytorch/pytorch/blob/a00fad017719346bac6e08da0819358146e647e3/torch/amp/grad_scaler.py#L398
+ var f = optimizer.GetType().GetField("_step_support_amp_scaling");
+ if (f != null && f.GetValue(optimizer) is bool b && !b) {
+ bool has_grad_scaler = false; // TODO(review): unclear how to detect a nested grad scaler here -- confirm
+ if (has_grad_scaler) {
+
+ throw new NotImplementedException();
+ } else {
+ if (optimizer_state["stage"] is OptState optstate && optstate == OptState.Ready)
+ check_inf_per_device(optimizer);
+ var scaler = _get_scale_async();
+ Debug.Assert(!(scaler is null), "!scaler.is_null()");
+ torch.Tensor found_inf=null;
+ if (optimizer_state["found_inf_per_device"] is torch.Tensor[] ts) {
+ for (int i = 0; i < ts.Length; i++)
+ ts[i].to(scaler.device, true);
+ found_inf=torch.sum(torch.cat(ts));
+ }
+
+ optimizer.grad_scale = (optimizer_state["stage"] as OptState?) == OptState.Unscaled ? null : scaler * ((optimizer.grad_scale is null) ? 1 : optimizer.grad_scale);
+ optimizer.found_inf = found_inf;
+
+ //if(optimizer is SGD ad)
+ //Info: All PyTorch optimizers have grad_scale and found_inf //https://github.com/pytorch/pytorch/blob/main/torch/optim/adam.py, etc.
+ //DANGER: TorchSharp optimizers do not have grad_scale or found_inf; grad_scale is needed for https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/amp/grad_scaler.py#L440
+ //optimizer.GetType().GetField("grad_scale").GetValue(optimizer) as torch.Tensor t
+ }
+ retval = optimizer.step().item();
+ optimizer_state["stage"] = OptState.Stepped;
+ //https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/amp/grad_scaler.py#L445
+ return retval;
+ }
+ if (optimizer_state["stage"] is OptState state1 && state1 == OptState.Ready)
+ unscale(optimizer);
+ if (optimizer_state["found_inf_per_device"] is ICollection col)
+ {
+ Debug.Assert(col.Count > 0, "(optimizer_state['found_inf_per_device'] as torch.Tensor).size(0) > 0");
+ }
+ //Debug.Assert((optimizer_state["found_inf_per_device"] as Dictionary>)?.Count > 0, "(optimizer_state['found_inf_per_device'] as torch.Tensor).size(0) > 0");
+ retval = maybe_opt_step(optimizer, optimizer_state, optimizer_args);
+ optimizer_state["stage"] = OptState.Stepped;
+ return retval;
+ }
+
+ private torch.Tensor _get_scale_async()
+ {
+ return _scale;
+ }
+
+ /// <summary>
+ /// Updates the scale factor for the next iteration.
+ /// </summary>
+ /// <param name="new_scale">New scale; only float or torch.Tensor are accepted.</param>
+ public void update(object new_scale = null)
+ {
+ if (!Enabled)
+ return;
+ var tup = check_scale_growth_tracker("update");
+ _scale = tup.Item1;
+ _growth_tracker = tup.Item2;
+ if (new_scale != null) {
+ Debug.Assert(!(_scale is null));
+ if (new_scale is float f)
+ _scale.fill_(f);
+ else if(new_scale is torch.Tensor t) {
+ string reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or torch.FloatTensor with requires_grad = False.";
+ Debug.Assert(t.device == this.device, reason);
+ Debug.Assert(t.numel() == 1, reason);
+ Debug.Assert(!t.requires_grad, reason);
+ _scale.copy_(t);
+ }
+ } else {
+ List found_infs = new List();
+ foreach (var state in _per_optimizer_states) {
+ if (state.Value["found_inf_per_device"] is Dictionary d) {
+ foreach(var found_inf in d.Values)
+ found_infs.Add(found_inf.to(_scale.device, true));
+ }
+ }
+
+ /*foreach (var found_inf in state.Value) {
+ if (found_inf.Value is torch.Tensor t) {
+ found_infs.Add(t);
+ }
+
+ if (found_inf.Value is List ts) {
+ foreach(var te in ts)
+ found_infs.Add(te);
+ }
+ }*/
+
+ Debug.Assert(found_infs.Count > 0, "No inf checks were recorded prior to update.");
+ torch.Tensor found_inf_combined = found_infs[0];
+ if (found_infs.Count > 1)
+ for (int i = 1; i < found_infs.Count; i++)
+ found_inf_combined += found_infs[i];
+ torch.amp_update_scale_(_scale, _growth_tracker, found_inf_combined, (double)_growth_factor, (double)_backoff_factor, (long)_growth_interval);
+ }
+ //TODO: Implement defaultdict https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/amp/grad_scaler.py#L531
+ }
+
+ public void set_init_growth_tracker(long new_value)
+ {
+ InitGrowthTracker=new_value;
+ }
+
+ public torch.Tensor get_scale_async()
+ {
+ return _scale;
+ }
+ public float get_scale()
+ {
+ if (!this.Enabled)
+ return 1.0f;
+
+ var scale = _get_scale_async();
+ if (scale is null)
+ return InitScale;
+ return scale.item();
+ }
+
+ public float get_growth_factor()
+ {
+ return _growth_factor;
+ }
+
+ public float get_backoff_factor()
+ {
+ return _backoff_factor;
+ }
+
+ public int get_growth_interval()
+ {
+ return _growth_interval;
+ }
+
+ public float get_init_growth_tracker()
+ {
+ return InitGrowthTracker; //TODO: Research this... should it be int64_t?
+ }
+ public bool IsEnabled()
+ {
+ return this.Enabled;
+ }
+
+ public UnorderedMap state_dict()
+ {
+ if (!Enabled)
+ return null;
+
+ var res = new UnorderedMap();
+ res["scale"] = get_scale();
+ res[nameof(_growth_factor)] = _growth_factor;
+ res[nameof(_backoff_factor)] = _backoff_factor;
+ res[nameof(_growth_interval)] = _growth_interval;
+ res[nameof(_growth_tracker)] = _growth_tracker;
+ return res;
+ }
+
+ public void load_state_dict(Dictionary state_dict)
+ {
+ if (!Enabled)
+ return;
+ if (state_dict.Count == 0)
+ throw new Exception("The source state dict is empty, possibly because it was saved from a disabled instance of GradScaler.");
+ //TODO: implement reflection to set field/properties based on state_dict
+ }
+
+ torch.Tensor check_inf_per_device(torch.optim.Optimizer optimizer)
+ {
+ _scale = check_scale_growth_tracker(nameof(check_inf_per_device)).Item1;
+ var dummy_inv_scale = torch.full(new ReadOnlySpan(new long[] { 0 }), 1.0f, torch.ScalarType.Float32, _scale.device);
+ var foundd_inf = torch.full(new ReadOnlySpan(new long[] { 0 }), 0.0f, torch.ScalarType.Float32, _scale.device);
+ _per_optimizer_states[optimizer.GetHashCode()]["found_inf_per_device"] = unscale_grads(optimizer, dummy_inv_scale, foundd_inf, true);
+ return _per_optimizer_states[optimizer.GetHashCode()]["found_inf_per_device"] as torch.Tensor;
+ }
+
+ private object _found_inf_per_device(torch.optim.Optimizer optimizer)
+ {
+ return _per_optimizer_states[optimizer.GetHashCode()]["found_inf_per_device"];
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (!disposedValue) {
+ if (disposing) {
+ _per_optimizer_states.Dispose();
+ _growth_tracker.Dispose();
+ _scale.Dispose();
+ // TODO: dispose managed state (managed objects)
+ }
+
+ // TODO: free unmanaged resources (unmanaged objects) and override finalizer
+ // TODO: set large fields to null
+ disposedValue = true;
+ }
+ }
+
+ // // TODO: override finalizer only if 'Dispose(bool disposing)' has code to free unmanaged resources
+ // ~GradScaler()
+ // {
+ // // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
+ // Dispose(disposing: false);
+ // }
+
+ public void Dispose()
+ {
+ // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
+ Dispose(disposing: true);
+ GC.SuppressFinalize(this);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/TorchSharp/Autograd.cs b/src/TorchSharp/Autograd.cs
index 4c73fce46..d7c29cc24 100644
--- a/src/TorchSharp/Autograd.cs
+++ b/src/TorchSharp/Autograd.cs
@@ -2,6 +2,7 @@
using System;
using System.Linq;
using System.Collections.Generic;
+using TorchSharp.Modules;
using static TorchSharp.PInvoke.NativeMethods;
namespace TorchSharp
@@ -145,6 +146,25 @@ public static IList grad(IList outputs, IList inputs, IL
return results.Array.Select(x => new Tensor(x)).ToList();
}
+ public static IList grad(IList inputs, IEnumerable outputs, IList grad_outputs = null, bool retain_graph = false, bool create_graph = false, bool allow_unused = false)
+ {
+ using var outs = new PinnedArray();
+ using var ins = new PinnedArray();
+ using var grads = new PinnedArray();
+ using var results = new PinnedArray();
+
+ IntPtr insRef = outs.CreateArray(outputs.Select(p => p.Handle).ToArray());
+ IntPtr outsRef = ins.CreateArray(inputs.Select(p => p.Handle).ToArray());
+ IntPtr gradsRef = grad_outputs == null ? IntPtr.Zero : grads.CreateArray(grad_outputs.Select(p => p.Handle).ToArray());
+ long gradsLength = grad_outputs == null ? 0 : grads.Array.Length;
+
+ //https://gist.github.com/dorpxam/67ad2bc222b2cf567d4a6fc298375e13#file-gradscaler_test-hpp-L318
+
+ THSAutograd_grad(outsRef, ins.Array.Length, insRef, outs.Array.Length, gradsRef, gradsLength, retain_graph, create_graph, allow_unused, results.CreateArray);
+ CheckForErrors();
+ return results.Array.Select(x => new Tensor(x)).ToList();
+ }
+
///
/// Computes the sum of gradients of given tensors with respect to graph leaves.
///
diff --git a/src/TorchSharp/LinearAlgebra.cs b/src/TorchSharp/LinearAlgebra.cs
index 0abb63d1d..4e8168b05 100644
--- a/src/TorchSharp/LinearAlgebra.cs
+++ b/src/TorchSharp/LinearAlgebra.cs
@@ -2,6 +2,7 @@
using System;
using System.Linq;
using System.Collections.Generic;
+using TorchSharp.Amp;
using static TorchSharp.PInvoke.NativeMethods;
#nullable enable
@@ -440,6 +441,7 @@ public static Tensor multi_dot(IList tensors)
throw new ArgumentException(nameof(tensors));
}
if (tensors.Count == 1) {
+ tensors[0] = AutocastMode.AutoCast(tensors[0]);
return tensors[0];
}
@@ -448,6 +450,7 @@ public static Tensor multi_dot(IList tensors)
var res = THSLinalg_multi_dot(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero)
torch.CheckForErrors();
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Activation/GELU.cs b/src/TorchSharp/NN/Activation/GELU.cs
index 90c314b99..ec13c0d3d 100644
--- a/src/TorchSharp/NN/Activation/GELU.cs
+++ b/src/TorchSharp/NN/Activation/GELU.cs
@@ -32,23 +32,21 @@ public static partial class torch
{
public static partial class nn
{
- ///
- /// Gaussian Error Linear Units
- ///
- public static GELU GELU()
+ public enum Approx
{
- return new GELU(false);
+ none,
+ tanh
}
-
///
/// Gaussian Error Linear Units
///
- /// Do the operation in-place. Default: False
- public static GELU GELU(bool inplace)
+ ///
+ public static GELU GELU(torch.nn.Approx approximate = Approx.none)
{
- return new GELU(inplace);
+ var handle = THSNN_GELU_ctor(out var boxedHandle, approximate.ToString());
+ if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new GELU(handle, boxedHandle);
}
-
public static partial class functional
{
///
diff --git a/src/TorchSharp/NN/Activation/Softmin.cs b/src/TorchSharp/NN/Activation/Softmin.cs
index 9ddf9e27a..00614b51f 100644
--- a/src/TorchSharp/NN/Activation/Softmin.cs
+++ b/src/TorchSharp/NN/Activation/Softmin.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -39,7 +40,10 @@ public static partial class nn
///
public static Softmin Softmin(long dim)
{
- return new Softmin(dim);
+ var handle = THSNN_Softmin_ctor(dim, out var boxedHandle);
+ if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
+ handle = AutocastMode.AutoCast(handle, ScalarType.Float32); // TODO(review): should AutoCast be applied at construction time? Confirm.
+ return new Softmin(handle, boxedHandle);
}
public static partial class functional
diff --git a/src/TorchSharp/NN/Activation/Softplus.cs b/src/TorchSharp/NN/Activation/Softplus.cs
index 0018c4f5d..b274392cf 100644
--- a/src/TorchSharp/NN/Activation/Softplus.cs
+++ b/src/TorchSharp/NN/Activation/Softplus.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -42,7 +43,10 @@ public static partial class nn
///
public static Softplus Softplus(double beta = 1, double threshold = 20)
{
- return new Softplus(beta, threshold);
+ var handle = THSNN_Softplus_ctor(beta, threshold, out var boxedHandle);
+ if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
+ handle = AutocastMode.AutoCast(handle, ScalarType.Float32); // TODO(review): should AutoCast be applied at construction time? Confirm.
+ return new Softplus(handle, boxedHandle);
}
public static partial class functional
diff --git a/src/TorchSharp/NN/Bilinear.cs b/src/TorchSharp/NN/Bilinear.cs
index be96adb76..a6081bca7 100644
--- a/src/TorchSharp/NN/Bilinear.cs
+++ b/src/TorchSharp/NN/Bilinear.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -7,6 +8,7 @@
#nullable enable
namespace TorchSharp
{
+ using System.Linq;
using Modules;
using TorchSharp.Utils;
@@ -148,7 +150,17 @@ public static Tensor bilinear(Tensor input1, Tensor input2, Tensor weight, Tenso
{
IntPtr bPtr = bias?.Handle ?? IntPtr.Zero;
var res = THSNN_functional_bilinear(input1.Handle, input2.Handle, weight.Handle, bPtr);
- if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ /*if (AutocastMode.IsAutocastEnabled()) {
+ var st = input1.dtype;
+ var st1 = input2.dtype;
+ var st2 = weight.dtype;
+ var sts = new[] { st, st1, st2 };
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, tensor1.handle, tensor2.handle) = AutocastMode.AutoCast(handle, tensor1.handle, tensor2.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, tensor1.handle, tensor2.handle) = AutocastMode.AutoCast(handle, tensor1.handle, tensor2.handle, ScalarType.Float32);
+ }*/
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Convolution/Conv1D.cs b/src/TorchSharp/NN/Convolution/Conv1D.cs
index 3f57e9cd0..aa0a38801 100644
--- a/src/TorchSharp/NN/Convolution/Conv1D.cs
+++ b/src/TorchSharp/NN/Convolution/Conv1D.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -12,8 +13,14 @@ namespace Modules
{
public sealed class Conv1d : Convolution
{
- internal Conv1d(long in_channels, long out_channels, long kernel_size, long stride, long? padding, Padding? padding_type, long dilation, long groups = 1, bool bias = true, PaddingModes padding_mode = PaddingModes.Zeros, torch.Device? device = null, ScalarType? dtype = null)
- : base(nameof(Conv1d), in_channels, out_channels, new[] { kernel_size }, new[] { stride }, padding.HasValue ? new[] { padding.Value } : null, padding_type, new[] { dilation }, false, new[] { 0L }, groups, bias, padding_mode, device, dtype) { }
+ internal long _dimension, _in_channel, _out_channel, _kernel,_stride, _padding,_dilation,_groups;
+ internal PaddingModes _paddingModes;
+ internal (long, long)? _kernels, _strides, _paddings, _dilations;
+ internal bool _bias;
+ protected Convolution(IntPtr handle, IntPtr boxedHandle, long input_channels) : base(handle, boxedHandle)
+ {
+ this.input_channels = input_channels;
+ }
public override Tensor forward(Tensor input)
{
@@ -54,7 +61,19 @@ public static partial class nn
/// Tensor of shape (N,C_out,L_out)
public static Conv1d Conv1d(long in_channels, long out_channels, long kernel_size, long stride = 1, long padding = 0, long dilation = 1, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true, Device? device = null, ScalarType? dtype = null)
{
- return new Conv1d(in_channels, out_channels, kernel_size, stride, padding, null, dilation, groups, bias, padding_mode, device, dtype);
+ var res = THSNN_Conv1d_ctor(in_channels, out_channels, kernelSize, stride, padding, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Conv1d(res, boxedHandle, in_channels) {
+ _in_channel = in_channels,
+ _out_channel = out_channels,
+ _kernel = kernelSize,
+ _stride = stride,
+ _padding = padding,
+ _dilation = dilation,
+ _paddingModes = padding_mode,
+ _groups = groups,
+ _bias = bias
+ }.MoveModule(device, dtype);
}
///
@@ -74,7 +93,19 @@ public static Conv1d Conv1d(long in_channels, long out_channels, long kernel_siz
/// Tensor of shape (N,C_out,L_out)
public static Conv1d Conv1d(long in_channels, long out_channels, long kernel_size, Padding padding, long stride = 1, long dilation = 1, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true, Device? device = null, ScalarType? dtype = null)
{
- return new Conv1d(in_channels, out_channels, kernel_size, stride, null, padding, dilation, groups, bias, padding_mode, device, dtype);
+ var res = THSNN_Conv1d_ctor(in_channels, out_channels, kernelSize, stride, padding == Padding.Valid ? 0 : -1, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Conv1d(res, boxedHandle, in_channels) {
+ _in_channel = in_channels,
+ _out_channel = out_channels,
+ _kernel = kernelSize,
+ _stride = stride,
+ _padding = (long)padding,
+ _dilation = dilation,
+ _paddingModes = padding_mode,
+ _groups = groups,
+ _bias = bias
+ }.MoveModule(device, dtype);
}
public static partial class functional
@@ -109,6 +140,7 @@ public static Tensor conv1d(Tensor input, Tensor weight, Tensor? bias = null,
(IntPtr)pdilation, dilationArray.Length,
groups);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Convolution/Conv2D.cs b/src/TorchSharp/NN/Convolution/Conv2D.cs
index 2f6ed3f04..022e4bb6f 100644
--- a/src/TorchSharp/NN/Convolution/Conv2D.cs
+++ b/src/TorchSharp/NN/Convolution/Conv2D.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -12,9 +13,37 @@ namespace Modules
{
public sealed class Conv2d : Convolution
{
- internal Conv2d(long in_channels, long out_channels, (long, long) kernel_size, (long, long) stride, (long, long)? padding, Padding? padding_type, (long, long) dilation, long groups = 1, bool bias = true, PaddingModes padding_mode = PaddingModes.Zeros, torch.Device? device = null, ScalarType? dtype = null)
- : base(nameof(Conv2d), in_channels, out_channels, new[] { kernel_size.Item1, kernel_size.Item2 }, new[] { stride.Item1, stride.Item2 }, padding.HasValue ? new[] { padding.Value.Item1, padding.Value.Item2 } : null, padding_type, new[] { dilation.Item1, dilation.Item2 }, false, new[] { 0L, 0L }, groups, bias, padding_mode, device, dtype) { }
+
+ internal Conv2d(IntPtr handle, IntPtr boxedHandle, long input_channels) : base(handle, boxedHandle, input_channels) { }
+ internal Conv2d(IntPtr handle, IntPtr boxedHandle, long input_channels, long in_channels, long out_channels, long kernelSize, long padding, long stride = 1, long dilation = 1, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true)
+ : base(handle, boxedHandle, input_channels)
+ {
+ _dimension = 2; // 2 because this is a 2-D convolution
+ _in_channel = in_channels;
+ _out_channel = out_channels;
+ _kernel = kernelSize;
+ _stride = stride;
+ _padding = padding;
+ _dilation = dilation;
+ _paddingModes = padding_mode;
+ _groups = groups;
+ _bias = bias;
+ }
+ internal Conv2d(IntPtr handle, IntPtr boxedHandle, long input_channels, long in_channels, long out_channels, (long, long) kernelSize, Padding padding, (long, long)? stride = null, (long, long)? dilation = null, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true)
+ : base(handle, boxedHandle, input_channels)
+ {
+ _dimension = 2; // 2 because this is a 2-D convolution
+ _in_channel = in_channels;
+ _out_channel = out_channels;
+ _kernels = kernelSize;
+ _strides = stride;
+ _padding = (long)padding;
+ _dilations = dilation;
+ _paddingModes = padding_mode;
+ _groups = groups;
+ _bias = bias;
+ }
public override Tensor forward(Tensor input)
{
if (!ValidateShape(input, 2))
@@ -54,7 +83,21 @@ public static partial class nn
///
public static Conv2d Conv2d(long in_channels, long out_channels, long kernel_size, long stride = 1, long padding = 0, long dilation = 1, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true, Device? device = null, ScalarType? dtype = null)
{
- return new Conv2d(in_channels, out_channels, (kernel_size, kernel_size), (stride, stride), (padding, padding), null, (dilation, dilation), groups, bias, padding_mode, device, dtype);
+ var res = THSNN_Conv2d_ctor(in_channels, out_channels, kernelSize, stride, padding, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+
+ return new Conv2d(res, boxedHandle, in_channels) {
+ _in_channel = in_channels,
+ _out_channel = out_channels,
+ _kernel = kernelSize,
+ _stride = stride,
+ _padding = padding,
+ _dilation = dilation,
+ _paddingModes = padding_mode,
+ _groups = groups,
+ _bias = bias
+ }.MoveModule(device, dtype);
+ //return conv2d.MoveModule(device, dtype);
}
///
@@ -78,7 +121,19 @@ public static Conv2d Conv2d(long in_channels, long out_channels, (long, long) ke
padding ??= (0, 0);
dilation ??= (1, 1);
- return new Conv2d(in_channels, out_channels, kernel_size, stride.Value, padding, null, dilation.Value, groups, bias, padding_mode, device, dtype);
+ var res = THSNN_Conv2d_ctor_1(in_channels, out_channels, kernelSize.Item1, kernelSize.Item2, stride.Value.Item1, stride.Value.Item2, padding.Value.Item1, padding.Value.Item2, dilation.Value.Item1, dilation.Value.Item2, (long)padding_mode, groups, bias, out var boxedHandle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Conv2d(res, boxedHandle, in_channels) {
+ _in_channel = in_channels,
+ _out_channel = out_channels,
+ _kernels = kernelSize,
+ _strides = stride,
+ _paddings = padding,
+ _dilations = dilation,
+ _paddingModes = padding_mode,
+ _groups = groups,
+ _bias = bias
+ }.MoveModule(device, dtype);
}
///
@@ -98,7 +153,9 @@ public static Conv2d Conv2d(long in_channels, long out_channels, (long, long) ke
///
public static Conv2d Conv2d(long in_channels, long out_channels, long kernel_size, Padding padding, long stride = 1, long dilation = 1, PaddingModes padding_mode = PaddingModes.Zeros, long groups = 1, bool bias = true, Device? device = null, ScalarType? dtype = null)
{
- return new Conv2d(in_channels, out_channels, (kernel_size, kernel_size), (stride, stride), null, padding, (dilation, dilation), groups, bias, padding_mode, device, dtype);
+ var res = THSNN_Conv2d_ctor(in_channels, out_channels, kernelSize, stride, padding == Padding.Valid ? 0 : -1, dilation, (long)padding_mode, groups, bias, out var boxedHandle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Conv2d(res, boxedHandle, in_channels, in_channels, out_channels, kernelSize, (long)padding, stride, dilation, padding_mode, groups, bias).MoveModule(device, dtype);
}
///
@@ -121,7 +178,10 @@ public static Conv2d Conv2d(long in_channels, long out_channels, (long, long) ke
stride ??= (1, 1);
dilation ??= (1, 1);
- return new Conv2d(in_channels, out_channels, kernel_size, stride.Value, null, padding, dilation.Value, groups, bias, padding_mode, device, dtype);
+ var res = THSNN_Conv2d_ctor_1(in_channels, out_channels, kernelSize.Item1, kernelSize.Item2, stride.Value.Item1, stride.Value.Item2, padding == Padding.Valid ? 0 : -1, 0, dilation.Value.Item1, dilation.Value.Item2, (long)padding_mode, groups, bias, out var boxedHandle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+
+ return new Conv2d(res, boxedHandle, in_channels, in_channels, out_channels, kernelSize, padding,stride, dilation, padding_mode ,groups,bias).MoveModule(device, dtype);
}
public static partial class functional
@@ -156,6 +216,7 @@ public static Tensor conv2d(Tensor input, Tensor weight, Tensor? bias = null,
(IntPtr)pdilation, dilation.Length,
groups);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Convolution/Conv3D.cs b/src/TorchSharp/NN/Convolution/Conv3D.cs
index d98ca6855..fb971b426 100644
--- a/src/TorchSharp/NN/Convolution/Conv3D.cs
+++ b/src/TorchSharp/NN/Convolution/Conv3D.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -151,6 +152,7 @@ public static Tensor conv3d(Tensor input, Tensor weight, Tensor? bias = null,
(IntPtr)pdilation, dilation.Length,
groups);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs b/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs
index a4c886585..d186b52ba 100644
--- a/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs
+++ b/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -87,6 +88,7 @@ public static Tensor conv_transpose1d(Tensor input, Tensor weight, Tensor? bias
(IntPtr)pdilation, dilations.Length,
groups);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs b/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs
index 02aa4eb06..bc1cee3b3 100644
--- a/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs
+++ b/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -114,6 +115,7 @@ public static Tensor conv_transpose2d(Tensor input, Tensor weight, Tensor? bias
(IntPtr)pdilation, dilation.Length,
groups);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs b/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs
index 6d9604f5b..dbc3ffc23 100644
--- a/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs
+++ b/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -111,6 +112,7 @@ public static Tensor conv_transpose3d(Tensor input, Tensor weight, Tensor? bias
(IntPtr)pdilation, dilation.Length,
groups);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/CosineSimilarity.cs b/src/TorchSharp/NN/CosineSimilarity.cs
index e4b8ea04c..ae41ddbb6 100644
--- a/src/TorchSharp/NN/CosineSimilarity.cs
+++ b/src/TorchSharp/NN/CosineSimilarity.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -22,7 +23,10 @@ internal CosineSimilarity(long dim = 1, double eps = 1e-8) : base(nameof(CosineS
public override Tensor forward(Tensor input1, Tensor input2)
{
- return torch.nn.functional.cosine_similarity(input1, input2, this.dim, this.eps);
+ var res = THSNN_CosineSimilarity_forward(handle, input1.Handle, input2.Handle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res= AutocastMode.AutoCast(res, ScalarType.Float32);
+ return new Tensor(res);
}
public long dim { get; set; }
@@ -42,7 +46,10 @@ public static partial class nn
///
public static CosineSimilarity CosineSimilarity(long dim = 1, double eps = 1e-8)
{
- return new CosineSimilarity(dim, eps);
+ var handle = THSNN_CosineSimilarity_ctor(dim, eps, out var boxedHandle);
+ if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
+ handle = AutocastMode.AutoCast(handle, ScalarType.Float32);
+ return new CosineSimilarity(handle, boxedHandle);
}
public static partial class functional
diff --git a/src/TorchSharp/NN/Dropout2d.cs b/src/TorchSharp/NN/Dropout2d.cs
index c0d8f20e5..8f33b2927 100644
--- a/src/TorchSharp/NN/Dropout2d.cs
+++ b/src/TorchSharp/NN/Dropout2d.cs
@@ -25,8 +25,14 @@ public override Tensor forward(Tensor input)
return torch.nn.functional.dropout2d(input, this.p, this.training, this.inplace);
}
- public bool inplace { get; set; }
- public double p { get; set;}
+ // Rather than spending cycles only to discover that this module has neither
+ // parameters nor buffers, just shortcut the move completely.
+ protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
+ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
+ protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
+
+ internal bool inplace; // Set internal accessibility for PrintModule
+ internal double p; // Set internal accessibility for PrintModule
}
}
diff --git a/src/TorchSharp/NN/Linear.cs b/src/TorchSharp/NN/Linear.cs
index c7b54a1cf..bb5f6c9f3 100644
--- a/src/TorchSharp/NN/Linear.cs
+++ b/src/TorchSharp/NN/Linear.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -12,20 +13,25 @@ namespace TorchSharp
namespace Modules
{
+ public class LinearInfo
+ {
+ public long InFeatures { get; }
+ public long OutFeatures { get; }
+ public LinearInfo(long inFeatures, long outFeatures)
+ {
+ InFeatures = inFeatures;
+ OutFeatures = outFeatures;
+ }
+ }
public sealed class Linear : torch.nn.Module
{
- const string WeightComponentName = nameof(weight);
- const string BiasComponentName = nameof(bias);
-
- internal Linear(Parameter weight, Parameter? bias = null) : base(nameof(Linear))
+ public LinearInfo linearInfo;
+ /*internal Linear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle)
{
- this.in_features = weight.shape[1];
- this.out_features = weight.shape[0];
-
- this.weight = weight;
- if (bias is not null) {
- this.bias = bias;
- }
+ }*/
+ internal Linear(IntPtr handle, IntPtr boxedHandle, long inFeat, long outFeat) : base(handle, boxedHandle)
+ {
+ linearInfo = new LinearInfo(inFeat, outFeat);
}
internal Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Linear))
@@ -47,7 +53,10 @@ internal Linear(long inputSize, long outputSize, bool hasBias = true, Device? de
public override Tensor forward(Tensor tensor)
{
- return torch.nn.functional.linear(tensor, _weight!, _bias);
+ //tensor.handle = Amp.AMPManager.GetInstance().AutoCast(tensor.handle); //WARNING should be here???? Research
+ var res = THSNN_Linear_forward(handle, tensor.Handle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Tensor(res);
}
protected override void Dispose(bool disposing)
@@ -141,14 +150,7 @@ public static Linear Linear(long inputSize, long outputSize, bool hasBias = true
return new Linear(inputSize, outputSize, hasBias, device, dtype);
}
- ///
- /// Create a Linear module with the given weights and bias.
- ///
- /// The linear weight attribute.
- /// The additive linear bias. Optional.
- public static Linear Linear(Parameter weight, Parameter? bias = null)
- {
- return new Linear(weight, bias);
+ return new Linear(res, boxedHandle, inputSize, outputSize).MoveModule(device, dtype);
}
public static partial class functional
@@ -165,6 +167,7 @@ public static Tensor linear(Tensor input, Tensor weights, Tensor? bias = null)
IntPtr bPtr = bias?.Handle ?? IntPtr.Zero;
var res = THSNN_functional_linear(input.Handle, weights.Handle, bPtr);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
}
diff --git a/src/TorchSharp/NN/Losses.cs b/src/TorchSharp/NN/Losses.cs
index 5e514bef5..9aae89088 100644
--- a/src/TorchSharp/NN/Losses.cs
+++ b/src/TorchSharp/NN/Losses.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -365,6 +366,7 @@ public static Tensor binary_cross_entropy_with_logits(Tensor input, Tensor targe
{
var res = THSNN_binary_cross_entropy_with_logits(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, (long)reduction, pos_weights?.Handle ?? IntPtr.Zero);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -435,6 +437,7 @@ public static Tensor cosine_embedding_loss(Tensor input1, Tensor input2, Tensor
{
var res = THSNN_cosine_embedding_loss(input1.Handle, input2.Handle, target.Handle, margin, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -514,6 +517,7 @@ public static Tensor multi_label_margin_loss(Tensor input, Tensor target, Reduct
{
var res = THSNN_multilabel_margin_loss(input.Handle, target.Handle, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -547,6 +551,7 @@ public static Tensor multi_margin_loss(Tensor input, Tensor target, int p = 1, d
IntPtr h = (weight is null) ? IntPtr.Zero : weight.Handle;
var res = THSNN_multi_margin_loss(input.Handle, target.Handle, p, margin, h, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -561,6 +566,7 @@ public static Tensor mse_loss(Tensor input, Tensor target, Reduction reduction =
{
var res = THSNN_mse_loss(input.Handle, target.Handle, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -620,6 +626,7 @@ public static Tensor kl_div(Tensor input, Tensor target, bool log_target = true,
{
var res = THSNN_kl_div_loss(input.Handle, target.Handle, (long)reduction, log_target);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -744,6 +751,7 @@ public override Tensor forward(Tensor input, Tensor target)
var ii = ignore_index.HasValue ? ignore_index.Value : -100;
var res = THSNN_cross_entropy(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, ii, ignore_index.HasValue, (long)reduction, label_smoothing);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+
return new Tensor(res);
}
@@ -776,6 +784,7 @@ public override Tensor forward(Tensor input, Tensor target)
{
var res = THSNN_binary_cross_entropy_with_logits(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, (long)reduction, pos_weights?.Handle ?? IntPtr.Zero);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -793,6 +802,7 @@ public override Tensor forward(Tensor input1, Tensor input2, Tensor target)
{
var res = THSNN_cosine_embedding_loss(input1.Handle, input2.Handle, target.Handle, margin, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -829,6 +839,7 @@ public override Tensor forward(Tensor input, Tensor target)
{
var res = THSNN_hinge_embedding_loss(input.Handle, target.Handle, margin, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -863,6 +874,7 @@ public override Tensor forward(Tensor input1, Tensor input2, Tensor target)
{
var res = THSNN_margin_ranking_loss(input1.Handle, input2.Handle, target.Handle, margin, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -942,6 +954,7 @@ public override Tensor forward(Tensor input, Tensor target)
{
var res = THSNN_l1_loss(input.Handle, target.Handle, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
}
@@ -956,6 +969,7 @@ public override Tensor forward(Tensor input, Tensor target)
{
var res = THSNN_nll_loss(input.Handle, target.Handle, weight?.Handle ?? IntPtr.Zero, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
}
@@ -973,6 +987,7 @@ public override Tensor forward(Tensor input, Tensor target)
{
var res = THSNN_poisson_loss(input.Handle, target.Handle, log_input, full, eps, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1046,6 +1061,7 @@ public override Tensor forward(Tensor input, Tensor target)
{
var res = THSNN_smooth_l1_loss(input.Handle, target.Handle, (long)reduction, beta);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1062,6 +1078,7 @@ public override Tensor forward(Tensor input, Tensor target)
{
var res = THSNN_soft_margin_loss(input.Handle, target.Handle, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
}
@@ -1080,6 +1097,7 @@ public override Tensor forward(Tensor anchor, Tensor positive, Tensor negative)
{
var res = THSNN_triplet_margin_loss(anchor.Handle, positive.Handle, negative.Handle, margin, p, eps, swap, (long)reduction);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs
index 498bda549..757c822c9 100644
--- a/src/TorchSharp/NN/Module.cs
+++ b/src/TorchSharp/NN/Module.cs
@@ -754,6 +754,8 @@ public virtual void register_buffer(string name, Tensor tensor, bool persistent
if (!_internal_buffers.TryAdd(name, (tensor, persistent)))
throw new InvalidOperationException($"Tensor {name} is already registered.");
+
+
}
///
@@ -773,6 +775,13 @@ public virtual void register_parameter(string name, Parameter param)
if (!_internal_params.TryAdd(name, param))
throw new InvalidOperationException($"Parameter {name} is already registered.");
+
+ /*if (is_autocast_cache_enabled()) {
+ if (is_autocast_gpu_enabled())
+ param = param.to(get_autocast_dtype(CUDA)).AsParameter();
+ if (is_autocast_cpu_enabled())
+ param = param.to(get_autocast_dtype(CPU)).AsParameter();
+ }*/
}
///
@@ -813,11 +822,29 @@ public virtual void register_module(string name, Module submodule)
}
submodule.RegisterComponents();
-
+ /*if (!is_autocast_cache_enabled()) {
+ _internal_submodules.Add(name, submodule);
+ return;
+ }
+ if (is_autocast_gpu_enabled())
+ submodule = submodule.to(get_autocast_dtype(CUDA));
+ if (is_autocast_cpu_enabled())
+ submodule = submodule.to(get_autocast_dtype(CPU));
+ */
_internal_submodules.Add(name, submodule);
}
}
+ public virtual void unregister_module(string name)
+ {
+ if (_internal_submodules.ContainsKey(name))
+ _internal_submodules.Remove(name);
+ }
+ public virtual void unregister_module(Module module)
+ {
+ unregister_module(module.GetName());
+ }
+
protected void ConditionallyRegisterParameter(string name, Tensor value)
{
ConditionallyRegisterParameter(name, value as Parameter);
@@ -1122,7 +1149,9 @@ protected virtual void RegisterComponents()
_areComponentsRegistered = true;
}
- protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Device? device = null, ScalarType? dtype = null)
+
+
+ protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Device device = null, ScalarType? dtype = null)
{
if (!dtype.HasValue)
dtype = get_default_dtype();
@@ -1400,6 +1429,10 @@ public TResult call(T input)
input = modified;
}
+ /*if (is_autocast_cache_enabled()) { //Should i cast this for better managment???
+ if(input is Tensor)
+ }*/
+
var result = forward(input);
// Call post-hooks, if available.
diff --git a/src/TorchSharp/NN/Normalization/Functional.cs b/src/TorchSharp/NN/Normalization/Functional.cs
index cd1d08200..1ccbebdf2 100644
--- a/src/TorchSharp/NN/Normalization/Functional.cs
+++ b/src/TorchSharp/NN/Normalization/Functional.cs
@@ -102,6 +102,24 @@ public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor? w
return new Tensor(res);
}
+ ///
+ /// Applies Local Normalization.
+ ///
+ public static Tensor local_response_norm(Tensor input, long size, double alpha = 0.0001, double beta = 0.75, double k = 1.0)
+ {
+ var res = THSNN_local_response_norm(input.Handle, size, alpha, beta, k);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return new Tensor(res);
+ }
+
+ public static Tensor normalize(Tensor input, float p=2.0f, long dim=1, float eps= 1e-12f, Tensor output = null)
+ {
+ var res = THSNN_normalize(input.Handle, p, dim, eps, out _);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return new Tensor(res);
+ }
}
}
}
diff --git a/src/TorchSharp/NN/Normalization/GroupNorm.cs b/src/TorchSharp/NN/Normalization/GroupNorm.cs
index 3e9e1ad32..e16d1109c 100644
--- a/src/TorchSharp/NN/Normalization/GroupNorm.cs
+++ b/src/TorchSharp/NN/Normalization/GroupNorm.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -33,9 +34,11 @@ internal GroupNorm(long num_groups, long num_channels, double eps, bool affine,
public override Tensor forward(Tensor tensor)
{
- if (tensor.Dimensions < 3)
- throw new ArgumentException($"Invalid number of dimensions for GroupNorm argument: {tensor.Dimensions}");
- return F.group_norm(tensor, num_groups, weight, bias, eps);
+ if (tensor.Dimensions < 3) throw new ArgumentException($"Invalid number of dimensions for GroupNorm argument: {tensor.Dimensions}");
+ var res = THSNN_GroupNorm_forward(handle.DangerousGetHandle(), tensor.Handle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res= AutocastMode.AutoCast(res, ScalarType.Float32);
+ return new Tensor(res);
}
protected override void Dispose(bool disposing)
@@ -125,7 +128,12 @@ public static partial class nn
///
public static GroupNorm GroupNorm(long num_groups, long num_channels, double eps = 1e-05, bool affine = true, Device? device = null, ScalarType? dtype = null)
{
- return new GroupNorm(num_groups, num_channels, eps, affine, device, dtype);
+ unsafe {
+ var handle = THSNN_GroupNorm_ctor(num_groups, num_channels, eps, affine, out var boxedHandle);
+ if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
+ handle= AutocastMode.AutoCast(handle, ScalarType.Float32);
+ return new GroupNorm(handle, boxedHandle).MoveModule(device, dtype);
+ }
}
}
}
diff --git a/src/TorchSharp/NN/Normalization/LayerNorm.cs b/src/TorchSharp/NN/Normalization/LayerNorm.cs
index 6c8458c1d..77e751c85 100644
--- a/src/TorchSharp/NN/Normalization/LayerNorm.cs
+++ b/src/TorchSharp/NN/Normalization/LayerNorm.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -18,8 +19,8 @@ namespace Modules
///
public sealed class LayerNorm : torch.nn.Module
{
- const string WeightComponentName = nameof(weight);
- const string BiasComponentName = nameof(bias);
+ internal long[] _normalized_shape;
+ internal double _eps;
internal LayerNorm(long[] normalized_shape, double eps, bool elementwise_affine, bool bias, Device? device, ScalarType? dtype) : base(nameof(LayerNorm))
{
@@ -30,9 +31,11 @@ internal LayerNorm(long[] normalized_shape, double eps, bool elementwise_affine,
if (elementwise_affine)
{
weight = Parameter(torch.empty(normalized_shape, dtype, device));
+ //weight.handle = AutocastMode.AutoCast(weight.handle, ScalarType.Float32); //This is correct???
if (bias)
{
this.bias = Parameter(torch.empty(normalized_shape, dtype, device));
+ //bias.handle = AutocastMode.AutoCast(bias.handle, ScalarType.Float32); //This is correct???
}
}
diff --git a/src/TorchSharp/NN/PairwiseDistance.cs b/src/TorchSharp/NN/PairwiseDistance.cs
index b0d6ba627..7a8cb79d4 100644
--- a/src/TorchSharp/NN/PairwiseDistance.cs
+++ b/src/TorchSharp/NN/PairwiseDistance.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.PInvoke.NativeMethods;
@@ -40,7 +41,10 @@ public static partial class nn
{
public static PairwiseDistance PairwiseDistance(double p = 2.0, double eps = 1e-6, bool keepdim = false)
{
- return new PairwiseDistance(p, eps, keepdim);
+ var handle = THSNN_PairwiseDistance_ctor(p, eps, keepdim, out var boxedHandle);
+ if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
+ handle = AutocastMode.AutoCast(handle, ScalarType.Float32);
+ return new PairwiseDistance(handle, boxedHandle);
}
public static partial class functional
diff --git a/src/TorchSharp/NN/Parameter.cs b/src/TorchSharp/NN/Parameter.cs
index 86a7f29e5..897e99f97 100644
--- a/src/TorchSharp/NN/Parameter.cs
+++ b/src/TorchSharp/NN/Parameter.cs
@@ -39,6 +39,20 @@ public Parameter(Tensor data, bool requires_grad = true) :
internal Parameter(System.IntPtr handle) : base(handle)
{
}
+
+ ///
+ /// Exposes the parameter's underlying tensor directly, avoiding the need to cast the Parameter to torch.Tensor.
+ /// https://github.com/ultralytics/ultralytics/blob/dcde8bd23d12bbb4867ebf45f936dd37c2445974/ultralytics/nn/modules/conv.py#L78
+ ///
+ ///
+ public torch.Tensor data {
+ get {
+ return new Tensor(base.handle);
+ }
+ set {
+ handle = value.handle;
+ }
+ }
};
}
diff --git a/src/TorchSharp/NN/Recurrent/GRUCell.cs b/src/TorchSharp/NN/Recurrent/GRUCell.cs
index cea14644c..610762542 100644
--- a/src/TorchSharp/NN/Recurrent/GRUCell.cs
+++ b/src/TorchSharp/NN/Recurrent/GRUCell.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -106,6 +107,7 @@ public static GRUCell GRUCell(long inputSize, long hiddenSize, bool bias = true,
{
var res = THSNN_GRUCell_ctor(inputSize, hiddenSize, bias, out var boxedHandle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new GRUCell(res, boxedHandle).MoveModule(device, dtype);
}
}
diff --git a/src/TorchSharp/NN/Recurrent/LSTMCell.cs b/src/TorchSharp/NN/Recurrent/LSTMCell.cs
index d74ddfd60..44f6e5bbc 100644
--- a/src/TorchSharp/NN/Recurrent/LSTMCell.cs
+++ b/src/TorchSharp/NN/Recurrent/LSTMCell.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -108,6 +109,7 @@ public static LSTMCell LSTMCell(long inputSize, long hiddenSize, bool bias = tru
{
var res = THSNN_LSTMCell_ctor(inputSize, hiddenSize, bias, out var boxedHandle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new LSTMCell(res, boxedHandle).MoveModule(device, dtype);
}
}
diff --git a/src/TorchSharp/NN/Recurrent/RNNCell.cs b/src/TorchSharp/NN/Recurrent/RNNCell.cs
index 5748e1f12..05bf7088b 100644
--- a/src/TorchSharp/NN/Recurrent/RNNCell.cs
+++ b/src/TorchSharp/NN/Recurrent/RNNCell.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.PInvoke.NativeMethods;
@@ -112,6 +113,7 @@ public static RNNCell RNNCell(long inputSize, long hiddenSize, NonLinearities no
{
var res = THSNN_RNNCell_ctor(inputSize, hiddenSize, (long)nonLinearity, bias, out var boxedHandle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new RNNCell(res, boxedHandle).MoveModule(device, dtype);
}
}
diff --git a/src/TorchSharp/NN/Sequential.cs b/src/TorchSharp/NN/Sequential.cs
index 7ddc14fa7..c7b7faa65 100644
--- a/src/TorchSharp/NN/Sequential.cs
+++ b/src/TorchSharp/NN/Sequential.cs
@@ -32,7 +32,6 @@ public Sequential append(string name, torch.nn.IModule module)
Add(name, module);
return this;
}
-
internal void Add(string name, torch.nn.IModule sm)
{
var submodule = (torch.nn.Module)sm;
@@ -52,6 +51,12 @@ public Sequential append(torch.nn.IModule module)
return this;
}
+ public Sequential append(IList> modules)
+ {
+ for (int i = 0; i < modules.Count; i++)
+ Add(_modules.Count.ToString(), modules[i]);
+ return this;
+ }
internal void Add(torch.nn.IModule module)
{
var name = _modules.Count.ToString();
diff --git a/src/TorchSharp/NN/Vision.cs b/src/TorchSharp/NN/Vision.cs
index 5dd5fe6e2..654bef049 100644
--- a/src/TorchSharp/NN/Vision.cs
+++ b/src/TorchSharp/NN/Vision.cs
@@ -1,5 +1,7 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using System.Linq;
+using TorchSharp.Amp;
using static TorchSharp.PInvoke.NativeMethods;
#nullable enable
@@ -164,8 +166,17 @@ public static Tensor pad(Tensor input, long pad, PaddingModes mode = PaddingMode
public static Tensor grid_sample(Tensor input, Tensor grid, GridSampleMode mode = GridSampleMode.Bilinear, GridSamplePaddingMode padding_mode = GridSamplePaddingMode.Zeros, bool? align_corners = null)
{
byte ac = (byte)((align_corners.HasValue) ? (align_corners.Value ? 1 : 2) : 0);
+ if (AutocastMode.IsAutocastEnabled()) {
+ var sts = new[] { input.dtype, grid.dtype };
+ if (sts.All(x => x == ScalarType.Float16))
+ (input.handle, grid.handle) = AutocastMode.AutoCast(input.handle, grid.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (input.handle, grid.handle) = AutocastMode.AutoCast(input.handle, grid.handle, ScalarType.Float32);
+ }
+
var res = THSNN_grid_sample(input.Handle, grid.Handle, (byte)mode, (byte)padding_mode, ac);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+
return new Tensor(res);
}
diff --git a/src/TorchSharp/Optimizers/Optimizer.cs b/src/TorchSharp/Optimizers/Optimizer.cs
index 9c40f0765..93cc48d0f 100644
--- a/src/TorchSharp/Optimizers/Optimizer.cs
+++ b/src/TorchSharp/Optimizers/Optimizer.cs
@@ -21,6 +21,8 @@ public static partial class optim
///
public abstract partial class Optimizer : IDisposable
{
+ internal Tensor grad_scale;
+ internal Tensor found_inf;
///
/// Class wrapping PyTorch's optimzer object reference.
///
@@ -85,6 +87,9 @@ public void Dispose()
protected virtual void Dispose(bool disposing)
{
if (disposing && handle != null && !handle.IsInvalid) {
+
+ grad_scale?.Dispose();
+ found_inf?.Dispose();
handle.Dispose();
handle.SetHandleAsInvalid();
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSAmp.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSAmp.cs
new file mode 100644
index 000000000..cfc9cda91
--- /dev/null
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSAmp.cs
@@ -0,0 +1,46 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+#nullable enable
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+
+namespace TorchSharp.PInvoke
+{
+ internal static partial class NativeMethods
+ {
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSAmp_amp_foreach_non_finite_check_and_unscale_(IntPtr tensors, long tLength, IntPtr found_inf, IntPtr inv_scale);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSAmp_amp_update_scale_(IntPtr self, IntPtr growth_tracker, IntPtr found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSAmp_amp_update_scale_out(IntPtr outt,IntPtr self, IntPtr growth_tracker, IntPtr found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSAmp_amp_update_scale_outf(IntPtr self,IntPtr growth_tracker, IntPtr found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval, IntPtr outt);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSAMP_amp_update_scale(IntPtr self,IntPtr growth_tracker, IntPtr found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval, out IntPtr sec);
+ [DllImport("LibTorchSharp")]
+ internal static extern bool THSAmp_is_torch_function_mode_enabled();
+ [DllImport("LibTorchSharp")]
+ internal static extern bool THSAmp_is_autocast_cache_enabled();
+ [DllImport("LibTorchSharp")]
+ internal static extern bool THSAmp_is_autocast_available(int device_type);
+ [DllImport("LibTorchSharp")]
+ internal static extern bool THSAmp_is_autocast_enabled(int device_type);
+ [DllImport("LibTorchSharp")]
+ internal static extern sbyte THSAmp_get_autocast_dtype(int device_type);
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSAmp_autocast_increment_nesting();
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSAmp_autocast_decrement_nesting();
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSAmp_set_autocast_enabled(int device_type, bool enabled);
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSAmp_set_autocast_cache_enabled(bool enabled);
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSAmp_set_autocast_dtype(int device_type, sbyte dtype);
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSAmp_clear_autocast_cache();
+
+
+ }
+}
\ No newline at end of file
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs
index 8920a141a..d455f5746 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs
@@ -41,5 +41,18 @@ internal static partial class NativeMethods
internal static extern bool THSBackend_cuda_get_enable_math_sdp();
[DllImport("LibTorchSharp")]
internal static extern void THSBackend_cuda_set_enable_math_sdp([MarshalAs(UnmanagedType.U1)] bool flag);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSCuda_get_major_compute_capability(int device=0);
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSCuda_get_minor_compute_capability(int device = 0);
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSCuda_get_device_count(ref int count);
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSCuda_get_free_total(int device, ref int id, ref ulong free, ref ulong total);
+ [DllImport("LibTorchSharp")]
+ internal static extern ulong THSCuda_get_total_memory(int device);
+ [DllImport("LibTorchSharp")]
+ internal static extern ulong THSCuda_get_global_total_memory(int device);
}
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
index fd24a26c4..ebf11f326 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
@@ -551,9 +551,207 @@ internal static extern IntPtr THSNN_custom_module(
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_ConvTranspose2d_ctor_1(long inputChannel, long outputChannel, long kernelSizeX, long kernelSizeY, long strideX, long strideY, long paddingX, long paddingY, long outputPaddingX, long outputPaddingY, long dilationX, long dilationY, long paddingMode, long groups, [MarshalAs(UnmanagedType.U1)] bool bias, out IntPtr pBoxedModule);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm2d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm2d_bias(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm2d_set_bias(torch.nn.Module.HType module, IntPtr bias);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm2d_weight(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm2d_set_weight(torch.nn.Module.HType module, IntPtr weight);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm2d_reset_stats(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm2d_get_mean(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm2d_get_var(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm2d_get_batches(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm2d_set_mean(torch.nn.Module.HType module, IntPtr weight);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm2d_set_var(torch.nn.Module.HType module, IntPtr weight);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm2d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm3d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm3d_bias(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm3d_set_bias(torch.nn.Module.HType module, IntPtr bias);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm3d_weight(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm3d_set_weight(torch.nn.Module.HType module, IntPtr weight);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm3d_reset_stats(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm3d_get_mean(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm3d_get_var(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm3d_get_batches(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm3d_set_mean(torch.nn.Module.HType module, IntPtr weight);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_BatchNorm3d_set_var(torch.nn.Module.HType module, IntPtr weight);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_BatchNorm3d_ctor(long features, double eps, double momentum, [MarshalAs(UnmanagedType.U1)] bool affine, [MarshalAs(UnmanagedType.U1)] bool track_running_stats, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool1d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool1d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool1d_ctor(IntPtr pkernelSize, IntPtr pStrides, IntPtr pPadding, IntPtr pDilation, [MarshalAs(UnmanagedType.U1)] bool ceilMode, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxUnpool3d_forward(torch.nn.Module.HType module, IntPtr tensor, IntPtr indices, IntPtr outSize, int outputSizeLength);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxUnpool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ELU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ELU_ctor(double alpha, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_GELU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp", CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+ internal static extern IntPtr THSNN_GELU_ctor(out IntPtr pBoxedModule, [MarshalAs(UnmanagedType.LPStr)] string approximate);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_GLU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_GLU_ctor(long dim, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Hardshrink_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Hardshrink_ctor(double lambd, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Hardtanh_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Hardtanh_ctor(double min_val, double max_val, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Mish_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Mish_ctor(out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_PReLU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_PReLU_ctor(long nparams, double init, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_PReLU_weight(torch.nn.Module.HType module);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_PReLU_set_weight(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReLU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReLU_ctor([MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReLU6_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReLU6_ctor([MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_RReLU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_RReLU_ctor(double lower, double upper, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool casual);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_normalize(IntPtr input, float p, long dim, float eps, out IntPtr output);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_SELU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_SELU_ctor([MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Sigmoid_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Sigmoid_ctor(out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_SiLU_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_SiLU_ctor(out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softmax_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softmax_ctor(long dim, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softmax2d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softmax2d_ctor(out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softmin_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softmin_ctor(long dim, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softplus_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_Softplus_ctor(double beta, double threshold, out IntPtr pBoxedModule);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_Softshrink_forward(torch.nn.Module.HType module, IntPtr tensor);
@@ -564,7 +762,220 @@ internal static extern IntPtr THSNN_custom_module(
internal static extern IntPtr THSNN_Threshold_forward(torch.nn.Module.HType module, IntPtr tensor);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSNN_Threshold_ctor(double threshold, double value, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+ internal static extern IntPtr THSNN_Threshold_ctor(double threshold, double value, [MarshalAs(UnmanagedType.U1)] bool inplace, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_LocalResponseNorm_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_LocalResponseNorm_ctor(long size, double alpha, double beta, double k, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad1d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad1d_ctor(double value, long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad1d_ctor_tuple(double value, long padding_left, long padding_right, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad2d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad2d_ctor(double value, long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad2d_ctor_tuple(double value, long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad3d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad3d_ctor(double value, long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ConstantPad3d_ctor_tuple(double value, long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad1d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad1d_ctor(long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad1d_ctor_tuple(long padding_left, long padding_right, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad2d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad2d_ctor(long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad3d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad3d_ctor(long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReflectionPad3d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad1d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad1d_ctor(long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad1d_ctor_tuple(long padding_left, long padding_right, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad2d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad2d_ctor(long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad3d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad3d_ctor(long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ReplicationPad3d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, long padding_front, long padding_back, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ZeroPad2d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ZeroPad2d_ctor(long padding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_ZeroPad2d_ctor_tuple(long padding_left, long padding_right, long padding_top, long padding_bottom, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveAvgPool1d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveAvgPool1d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveAvgPool2d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveAvgPool2d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveAvgPool3d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveAvgPool3d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveMaxPool1d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveMaxPool1d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveMaxPool2d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveMaxPool2d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveMaxPool3d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AdaptiveMaxPool3d_ctor(IntPtr psizes, int length, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AvgPool1d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AvgPool1d_ctor(IntPtr pkernelSize, IntPtr pstrides, IntPtr ppadding, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AvgPool2d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AvgPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AvgPool3d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_AvgPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_FractionalMaxPool2d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_FractionalMaxPool2d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_FractionalMaxPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pOutputSize, int sizeLength, IntPtr pOutputRatio, int ratioLength, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_FractionalMaxPool3d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_FractionalMaxPool3d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_FractionalMaxPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pOutputSize, int sizeLength, IntPtr pOutputRatio, int ratioLength, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_LPPool1d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_LPPool1d_ctor(double norm_type, IntPtr pkernelSize, IntPtr pstrides, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_LPPool2d_forward(IntPtr module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_LPPool2d_ctor(double norm_type, IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool2d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool2d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, IntPtr pDilation, int dilationLength, [MarshalAs(UnmanagedType.U1)] bool ceilMode, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool3d_forward(torch.nn.Module.HType module, IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool3d_forward_with_indices(torch.nn.Module.HType module, IntPtr tensor, out IntPtr indices);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, IntPtr pDilation, int dilationLength, [MarshalAs(UnmanagedType.U1)] bool ceilMode, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxUnpool1d_forward(torch.nn.Module.HType module, IntPtr tensor, IntPtr indices, IntPtr outSize);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxUnpool1d_ctor(IntPtr pkernelSize, IntPtr pStrides, IntPtr pPadding, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxUnpool2d_forward(torch.nn.Module.HType module, IntPtr tensor, IntPtr indices, IntPtr outSize, int outputSizeLength);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSNN_MaxUnpool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr pPadding, int paddingLength, out IntPtr pBoxedModule);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSNN_Print_Module(torch.nn.Module.HType module);
}
#pragma warning restore CA2101
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSStorage.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSStorage.cs
index 7cf494b7a..bd5b46694 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSStorage.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSStorage.cs
@@ -15,5 +15,15 @@ internal static partial class NativeMethods
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSStorage_data_ptr(IntPtr tensor);
+ /*[DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSStorage_tensor_to_array_int(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSStorage_tensor_to_array_long(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSStorage_tensor_to_array_float(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSStorage_tensor_to_array_double(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSStorage_tensor_to_array_byte(IntPtr tensor);*/
}
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
index 65018f5a5..c1c84811d 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -321,11 +321,14 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_to_device(IntPtr handle, int device_type, int device_index, [MarshalAs(UnmanagedType.U1)] bool copy, [MarshalAs(UnmanagedType.U1)] bool non_blocking);
+ [DllImport("LibTorchSharp")]
+ //internal static extern IntPtr THSTensor_to_type_and_device(IntPtr handle, sbyte scalar_type, int device_type, int device_index, [MarshalAs(UnmanagedType.U1)] bool copy);
+ internal static extern IntPtr THSTensor_to_type_and_device(IntPtr handle, sbyte scalar_type, int device_type, int device_index, [MarshalAs(UnmanagedType.U1)] bool copy, [MarshalAs(UnmanagedType.U1)] bool non_blocking);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_to_type(IntPtr handle, sbyte scalar_type, [MarshalAs(UnmanagedType.U1)] bool copy, [MarshalAs(UnmanagedType.U1)] bool non_blocking);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSTensor_to_type_and_device(IntPtr handle, sbyte scalar_type, int device_type, int device_index, [MarshalAs(UnmanagedType.U1)] bool copy, [MarshalAs(UnmanagedType.U1)] bool non_blocking);
+ internal static extern IntPtr THSTensor_to_type_and_device_and_non_blocking(IntPtr handle, sbyte scalar_type, int device_type, int device_index, [MarshalAs(UnmanagedType.U1)] bool non_blocking);
[DllImport("LibTorchSharp")]
internal static extern void THSTensor_set_(IntPtr tensor, IntPtr source);
@@ -412,6 +415,16 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern void THSTensor_index_put_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);
+ /*
+ //NOTE: The index_put and with accumulate need passing to c10::List>()
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSTensor_index_put_accumulate_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value, [MarshalAs(UnmanagedType.I1)] bool accumulate);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_index_put(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_index_put_accumulate(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value, [MarshalAs(UnmanagedType.I1)] bool accumulate);*/
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_get1(IntPtr handle, long i1);
@@ -489,6 +502,8 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_reshape(IntPtr tensor, IntPtr shape, int length);
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSTensor_resize_(IntPtr tensor, IntPtr shape, int length);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_flatten(IntPtr tensor, long start, long end);
@@ -2188,6 +2203,11 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern IntPtr THSTensor_histogram_out_t(IntPtr input, IntPtr bins, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_histogram_out_i(IntPtr input, long bins, IntPtr range, int length, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_coalesce(IntPtr input);
+ [DllImport("LibTorchSharp")]
+ internal static extern bool THSTensor_is_coalesce(IntPtr input);
}
#pragma warning restore CA2101
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs
index fc67a88de..531b47d76 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs
@@ -19,5 +19,7 @@ internal static partial class NativeMethods
[DllImport("LibTorchSharp")]
internal static extern void THSTorchCuda_synchronize(long device_index);
+
+
}
}
diff --git a/src/TorchSharp/Special.cs b/src/TorchSharp/Special.cs
index 1b568376e..54947f1ab 100644
--- a/src/TorchSharp/Special.cs
+++ b/src/TorchSharp/Special.cs
@@ -1,5 +1,6 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
+using TorchSharp.Amp;
using static TorchSharp.PInvoke.NativeMethods;
#nullable enable
@@ -675,10 +676,11 @@ public static Tensor logit(Tensor input)
///
public static Tensor log_softmax(Tensor input, long dim, ScalarType? dtype = null)
{
- var dt = dtype.HasValue ? dtype.Value : input.dtype;
+ var dt = dtype ?? input.dtype;
var res = THSSpecial_log_softmax(input.Handle, dim, (sbyte)dt);
if (res == IntPtr.Zero)
torch.CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -746,6 +748,7 @@ public static Tensor softmax(Tensor input, long dim, ScalarType? dtype = null)
var res = THSSpecial_softmax(input.Handle, dim, (sbyte)dt);
if (res == IntPtr.Zero)
torch.CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
diff --git a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
index b306c0cd7..99aef9827 100644
--- a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
+++ b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
@@ -166,7 +166,7 @@ private static Tensor _tensor_generic(Array rawArray, ReadOnlySpan dimensi
unsafe {
void *ptr = null;
- IntPtr iPtr = (IntPtr)ptr;
+ IntPtr iPtr = (IntPtr)ptr; //Warning: Unused variable
fixed (long* shape = dimensions) {
var handle = THSTensor_new(dataArrayAddr, deleter, (IntPtr)shape, dimensions.Length, origType, (sbyte)dtype.Value, (int)device.type, device.index, requires_grad);
@@ -225,7 +225,7 @@ private static Tensor _tensor_generic(Memory rawArray, ReadOnlySpan
deleters.TryAdd(deleter, deleter); // keep the delegate alive
void *ptr = null;
- IntPtr iPtr = (IntPtr)ptr;
+ IntPtr iPtr = (IntPtr)ptr; //Warning: Unused variable
fixed (long* shape = dimensions) {
var handle = THSTensor_new(dataArrayAddr, deleter, (IntPtr)shape, dimensions.Length, origType, (sbyte)dtype.Value, (int)device.type, device.index, requires_grad);
@@ -243,6 +243,12 @@ private static Tensor _tensor_generic(Memory rawArray, ReadOnlySpan
tensor.rename_(names);
}
+ /*if (!is_autocast_cache_enabled())
+ return tensor;
+ if (is_autocast_gpu_enabled())
+ tensor = tensor.to(get_autocast_gpu_dtype());
+ if (is_autocast_cpu_enabled())
+ tensor = tensor.to(get_autocast_cpu_dtype());*/
return tensor;
}
}
diff --git a/src/TorchSharp/Tensor/Factories/tensor_float.cs b/src/TorchSharp/Tensor/Factories/tensor_float.cs
index 50ef429ab..6b70bd3fc 100644
--- a/src/TorchSharp/Tensor/Factories/tensor_float.cs
+++ b/src/TorchSharp/Tensor/Factories/tensor_float.cs
@@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
+using TorchSharp.Amp;
using static TorchSharp.PInvoke.NativeMethods;
#nullable enable
@@ -18,7 +19,15 @@ public static Tensor tensor(float scalar, Device? device = null, bool requires_g
device = InitializeDevice(device);
var handle = THSTensor_newFloat32Scalar(scalar, (int)device.type, device.index, requires_grad);
if (handle == IntPtr.Zero) { CheckForErrors(); }
- return new Tensor(handle);
+
+
+ //var t = new Tensor(handle).AutoCast();
+ var t = new Tensor(handle);
+ /*if (is_autocast_cache_enabled()) {
+ if (is_autocast_gpu_enabled())
+ return t.to(get_autocast_gpu_dtype()); //this work, but should put that on all tensor factorie...
+ }*/
+ return t;
}
///
diff --git a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs
index 53e6facfb..a26dc15b7 100644
--- a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs
+++ b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Linq;
+using TorchSharp.Amp;
using static TorchSharp.PInvoke.NativeMethods;
namespace TorchSharp
@@ -17,6 +18,13 @@ public partial class Tensor
public Tensor tensordot(Tensor b, long[] dims1, long[] dims2)
{
IntPtr res;
+ if (AutocastMode.IsAutocastEnabled()) {
+ var sts = new[] { this.dtype, b.dtype };
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, b.handle) = AutocastMode.AutoCast(handle, b.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, b.handle) = AutocastMode.AutoCast(handle, b.handle, ScalarType.Float32);
+ }
unsafe {
fixed (long* pdims1 = dims1, pdims2 = dims2) {
res = THSLinalg_tensordot(Handle, b.Handle,(IntPtr)pdims1, dims1.Length,(IntPtr)pdims2, dims2.Length);
@@ -110,7 +118,19 @@ public Tensor cross(Scalar other, long dim)
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
}
-
+ public Tensor cross(Tensor other, long dim)
+ {
+ if (AutocastMode.IsAutocastEnabled()) {
+ var sts = new[] { this.dtype, other.dtype};
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, other.handle)= AutocastMode.AutoCast(handle, other.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, other.handle) = AutocastMode.AutoCast(handle, other.handle, ScalarType.Float32);
+ }
+ var res = THSTensor_cross(Handle, other.Handle, dim);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
///
/// Computes the determinant of a square matrix.
///
@@ -171,6 +191,7 @@ public Tensor matmul(Tensor target)
{
var res = THSTensor_matmul(Handle, target.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -183,6 +204,7 @@ public Tensor mm(Tensor target)
{
var res = THSTensor_mm(Handle, target.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -195,6 +217,7 @@ public Tensor mv(Tensor target)
{
var res = THSTensor_mv(Handle, target.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -244,6 +267,13 @@ public Tensor vdot(Tensor target)
public Tensor dot(Tensor target)
{
if (shape.Length != 1 || target.shape.Length != 1 || shape[0] != target.shape[0]) throw new InvalidOperationException("dot arguments must have the same shape.");
+ if (AutocastMode.IsAutocastEnabled()) {
+ var sts = new[] { this.dtype, target.dtype };
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, target.handle) = AutocastMode.AutoCast(handle, target.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, target.handle) = AutocastMode.AutoCast(handle, target.handle, ScalarType.Float32);
+ }
var res = THSTensor_dot(Handle, target.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
diff --git a/src/TorchSharp/Tensor/Tensor.Math.cs b/src/TorchSharp/Tensor/Tensor.Math.cs
index fb7207638..0fec7e12f 100644
--- a/src/TorchSharp/Tensor/Tensor.Math.cs
+++ b/src/TorchSharp/Tensor/Tensor.Math.cs
@@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
+using System.Linq;
+using TorchSharp.Amp;
using static TorchSharp.PInvoke.NativeMethods;
namespace TorchSharp
@@ -157,6 +159,7 @@ public Tensor addbmm(Tensor batch1, Tensor batch2, float beta = 1, float alpha =
var res = THSTensor_addbmm(Handle, batch1.Handle, batch2.Handle, beta, alpha);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -186,6 +189,16 @@ public Tensor addbmm_(Tensor batch1, Tensor batch2, float beta = 1, float alpha
///
public Tensor addcdiv(Tensor tensor1, Tensor tensor2, Scalar value)
{
+ if (AutocastMode.IsAutocastEnabled(this.device.type)) {
+ var st = (ScalarType)THSTensor_type(Handle);
+ var st1 = (ScalarType)THSTensor_type(tensor1.Handle);
+ var st2 = (ScalarType)THSTensor_type(tensor2.Handle);
+ var sts = new[] { st, st1, st2 };
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, tensor1.handle, tensor2.handle) = AutocastMode.AutoCast(handle, tensor1.handle, tensor2.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, tensor1.handle, tensor2.handle) = AutocastMode.AutoCast(handle, tensor1.handle, tensor2.handle, ScalarType.Float32);
+ }
var res = THSTensor_addcdiv(Handle, tensor1.Handle, tensor2.Handle, value.Handle);
if (res == IntPtr.Zero)
CheckForErrors();
@@ -237,6 +250,23 @@ public Tensor addcdiv_(Tensor tensor1, Tensor tensor2)
///
public Tensor addcmul(Tensor tensor1, Tensor tensor2, Scalar value)
{
+ if (AutocastMode.IsAutocastEnabled(this.device.type)) {
+ /*
+ * These ops don’t require a particular dtype for stability, but take multiple inputs and require that the inputs’ dtypes match.
+ * If all of the inputs are float16, the op runs in float16.
+ * If any of the inputs is float32, autocast casts all inputs to float32 and runs the op in float32.
+ * https://pytorch.org/docs/stable/amp.html
+ */
+ var st = (ScalarType)THSTensor_type(Handle);
+ var st1 = (ScalarType)THSTensor_type(tensor1.Handle);
+ var st2 = (ScalarType)THSTensor_type(tensor2.Handle);
+ var sts = new[] { st, st1, st2 };
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, tensor1.handle, tensor2.handle) = AutocastMode.AutoCast(handle, tensor1.handle, tensor2.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, tensor1.handle, tensor2.handle) = AutocastMode.AutoCast(handle, tensor1.handle, tensor2.handle, ScalarType.Float32);
+ }
+
var res = THSTensor_addcmul(Handle, tensor1.Handle, tensor2.Handle, value.Handle);
if (res == IntPtr.Zero)
CheckForErrors();
@@ -270,6 +300,7 @@ public Tensor addmm(Tensor mat1, Tensor mat2, float beta = 1, float alpha = 1)
var res = THSTensor_addmm(Handle, mat1.Handle, mat2.Handle, beta, alpha);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -301,6 +332,7 @@ public Tensor addmv(Tensor mat, Tensor vec, float beta = 1.0f, float alpha = 1.0
var res = THSTensor_addmv(Handle, mat.Handle, vec.Handle, beta, alpha);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -332,6 +364,7 @@ public Tensor addr(Tensor vec1, Tensor vec2, float beta = 1.0f, float alpha = 1.
var res = THSTensor_addr(Handle, vec1.Handle, vec2.Handle, beta, alpha);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -646,6 +679,7 @@ public Tensor cumsum(long dim, ScalarType? type = null)
{
var res = THSTensor_cumsum(Handle, dim, type.HasValue, (sbyte)type.GetValueOrDefault());
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -660,6 +694,7 @@ public Tensor cumprod(long dim, ScalarType? type = null)
{
var res = THSTensor_cumprod(Handle, dim, type.HasValue, (sbyte)type.GetValueOrDefault());
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -754,6 +789,7 @@ public Tensor exp()
{
var res = THSTensor_exp(Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -786,6 +822,7 @@ public Tensor expm1()
{
var res = THSTensor_expm1(Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1025,6 +1062,7 @@ public Tensor log()
{
var res = THSTensor_log(Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1108,6 +1146,7 @@ public Tensor log10()
var res = THSTensor_log10(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1131,6 +1170,7 @@ public Tensor log1p()
var res = THSTensor_log1p(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1154,6 +1194,7 @@ public Tensor log2()
var res = THSTensor_log2(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1385,6 +1426,7 @@ public Tensor pow(Tensor exponent)
{
var res = THSTensor_pow(Handle, exponent.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32); //https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float32
return new Tensor(res);
}
@@ -1409,6 +1451,7 @@ public Tensor pow(Scalar exponent)
{
var res = THSTensor_pow_scalar(Handle, exponent.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1433,6 +1476,7 @@ public Tensor reciprocal()
var res = THSTensor_reciprocal(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1528,6 +1572,7 @@ public Tensor rsqrt()
{
var res = THSTensor_rsqrt(Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -1789,6 +1834,15 @@ public Tensor true_divide_(Scalar other)
return this;
}
+ /*public Tensor rtruediv_(Tensor other)
+ {
+ var res = THSTensor_true_divide(other.Handle, Handle);
+ if(res == IntPtr.Zero)
+ CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
+ return new Tensor(res);
+ }*/
+
///
/// Returns a new tensor with the truncated integer values of the elements of input.
///
diff --git a/src/TorchSharp/Tensor/Tensor.Trig.cs b/src/TorchSharp/Tensor/Tensor.Trig.cs
index d377e967c..86e5f0865 100644
--- a/src/TorchSharp/Tensor/Tensor.Trig.cs
+++ b/src/TorchSharp/Tensor/Tensor.Trig.cs
@@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Diagnostics.Contracts;
+using System.Linq;
+using TorchSharp.Amp;
using static TorchSharp.PInvoke.NativeMethods;
namespace TorchSharp
@@ -39,6 +41,7 @@ public Tensor asin()
var res = THSTensor_asin(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -70,6 +73,7 @@ public Tensor acos()
var res = THSTensor_acos(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -140,6 +144,13 @@ public Tensor atan_()
/// The second tensor
public Tensor atan2(Tensor other)
{
+ if (AutocastMode.IsAutocastEnabled()) {
+ var sts = new[] { this.dtype, other.dtype };
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, other.handle) = AutocastMode.AutoCast(handle, other.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, other.handle) = AutocastMode.AutoCast(handle, other.handle, ScalarType.Float32);
+ }
var res = THSTensor_atan2(Handle, other.Handle);
if (res == IntPtr.Zero)
CheckForErrors();
@@ -216,6 +227,7 @@ public Tensor tan()
var res = THSTensor_tan(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -262,6 +274,7 @@ public Tensor sinh()
var res = THSTensor_sinh(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -285,6 +298,7 @@ public Tensor cosh()
var res = THSTensor_cosh(Handle);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs
index 41b007c9e..9ea71635b 100644
--- a/src/TorchSharp/Tensor/Tensor.cs
+++ b/src/TorchSharp/Tensor/Tensor.cs
@@ -9,6 +9,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
+using TorchSharp.Amp;
using TorchSharp.PInvoke;
#nullable enable
@@ -34,7 +35,15 @@ public partial class Tensor : IDisposable
internal DisposeScope? OwningDisposeScope { get; set; }
- internal Tensor(IntPtr handle, bool register = true)
+ /*internal Tensor(IntPtr handle, IntPtr res)
+ {
+ if (AMPManager.GetInstance().IsEnabled) {
+ this.handle = AMPManager.GetInstance().Work(res, handle);
+ } else {
+ this.handle = handle;
+ }
+ }*/
+ internal Tensor(IntPtr handle)
{
this.handle = handle;
System.Threading.Interlocked.Increment(ref _totalCount);
@@ -211,6 +220,10 @@ public IntPtr Handle {
get {
if (handle == IntPtr.Zero)
throw new InvalidOperationException("Tensor invalid -- empty handle.");
+
+ /*if (AMPManager.GetInstance().IsEnabled) {
+ this.handle = AMPManager.GetInstance().Work(handle, this.handle); //MMM.... This is the more abstract of any method Tensor right????
+ }*/
return handle;
}
}
@@ -252,6 +265,7 @@ internal IntPtr MoveHandle()
///
public long numel() => NumberOfElements;
+ public bool is_null() => handle == IntPtr.Zero;
///
/// Get the size of each element in the tensor.
///
@@ -285,6 +299,21 @@ public bool is_nonzero()
return res != 0;
}
+ public bool is_coalesce()
+ {
+ var res = NativeMethods.THSTensor_is_coalesce(Handle);
+ CheckForErrors();
+ return res;
+ }
+
+ public Tensor coalesce()
+ {
+ var res = NativeMethods.THSTensor_coalesce(Handle);
+ if(res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
public bool is_cuda => device.type == DeviceType.CUDA;
public bool is_meta => device.type == DeviceType.META;
@@ -386,7 +415,9 @@ internal void ValidateType(Type dotnetType)
throw new ArgumentException($"{dotnetType.Name} is not compatible with {dtype.ToString()}");
break;
case ScalarType.BFloat16:
- throw new ArgumentException($"No support for {dtype.ToString()} in TorchSharp");
+ if(dotnetType != typeof(Half)) // NOTE(review): System.Half is IEEE binary16, not bfloat16 — confirm this mapping is intended
+ throw new ArgumentException($"No support for {dtype.ToString()} in TorchSharp");
+ break;
case ScalarType.Float16:
#if NET6_0_OR_GREATER
if (dotnetType != typeof(Half))
@@ -707,6 +738,7 @@ public bool is_sparse {
public void backward(IList? grad_tensors = null, bool retain_graph = false, bool create_graph = false, IList? inputs = null) =>
torch.autograd.backward(new[] { this }, grad_tensors, retain_graph, create_graph, inputs);
+
///
/// Creates a tensor by loading it from a file.
///
@@ -896,6 +928,24 @@ public Tensor to(ScalarType type, torch.Device device, bool copy = false, bool d
return new Tensor(res);
}
+ /*internal static void to(this IntPtr ptr, ScalarType type)
+ {
+ var res = NativeMethods.THSTensor_to_type(ptr, (sbyte)type);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ if (disposeAfter)
+ this.Dispose();
+ return new Tensor(res);
+ }*/
+ public Tensor to(torch.Device device, ScalarType type, bool non_blocking)
+ {
+ torch.InitializeDevice(device);
+ var res = NativeMethods.THSTensor_to_type_and_device_and_non_blocking(Handle, (sbyte)type, (int)device.type, device.index, non_blocking);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Cast the tensor to the given element type.
///
@@ -1613,6 +1663,24 @@ public Tensor index_put_(Tensor value, params TensorIndex[] indices)
}
}
}
+ /*///
+ /// Index into the tensor using Python-like indexing expressions and place a tensor at the index.
+ ///
+ private Tensor index_put_accumulate_(Tensor value, bool accumulate, params TensorIndex[] indices)
+ {
+ EncodeIndices(indices, out var arrKindAndStarts, out var arrStops, out var arrSteps, out var arrTensors);
+ unsafe {
+ fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
+ fixed (IntPtr* ptrTensors = arrTensors) {
+ NativeMethods.THSTensor_index_put_accumulate_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, accumulate);
+ CheckForErrors();
+ GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
+ GC.KeepAlive(value);
+ return this;
+ }
+ }
+ }
+ }*/
///
/// Index into the tensor using Python-like indexing expressions and place a tensor at the index.
@@ -1622,7 +1690,51 @@ public Tensor index_put_(Tensor value, params Tensor[] indices)
return index_put_(value, indices.Select(t => TensorIndex.Tensor(t)).ToArray());
}
+ /*public Tensor index_put_(Tensor value, bool accumulate, params TensorIndex[] indices)
+ {
+ return index_put_accumulate_(value, accumulate, indices);
+ }
+ public Tensor index_put_(Tensor value, bool accumulate, params Tensor[] indices)
+ {
+ return index_put_accumulate_(value, accumulate, indices.Select(t => TensorIndex.Tensor(t)).ToArray());
+ }
+ ///
+ /// Index into the tensor using Python-like indexing expressions and place a tensor at the index.
+ ///
+ private Tensor index_put_accumulate(Tensor value, bool accumulate, params TensorIndex[] indices)
+ {
+ EncodeIndices(indices, out var arrKindAndStarts, out var arrStops, out var arrSteps, out var arrTensors);
+ unsafe {
+ fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
+ fixed (IntPtr* ptrTensors = arrTensors) {
+ var res = NativeMethods.THSTensor_index_put_accumulate(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, accumulate);
+ CheckForErrors();
+ GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
+ GC.KeepAlive(value);
+ if(res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+ }
+ }
+ }*/
+
+ /*///
+ /// Index into the tensor using Python-like indexing expressions and place a tensor at the index.
+ ///
+ public Tensor index_put(Tensor value, params Tensor[] indices)
+ {
+ return index_put(value, indices.Select(t => TensorIndex.Tensor(t)).ToArray());
+ }*/
+ /*public Tensor index_put(Tensor value, bool accumulate, params TensorIndex[] indices)
+ {
+ return index_put_accumulate(value, accumulate, indices);
+ }
+ public Tensor index_put(Tensor value, bool accumulate, params Tensor[] indices)
+ {
+ return index_put_accumulate(value, accumulate, indices.Select(t => TensorIndex.Tensor(t)).ToArray());
+ }*/
///
/// Index into the tensor using Python-like indexing expressions and place a scalar tensor at the index.
///
@@ -1882,6 +1994,17 @@ public Tensor reshape(params long[] shape)
}
}
+ public Tensor resize_(params long[] shape)
+ {
+ unsafe {
+ fixed (long* pshape = shape) {
+ NativeMethods.THSTensor_resize_(Handle, (IntPtr)pshape, shape.Length);
+ }
+ }
+
+ return this;
+ }
+
///
/// Flattens input by reshaping it into a one-dimensional tensor.
///
@@ -3124,6 +3247,7 @@ public Tensor baddbmm(Tensor batch1, Tensor batch2, float beta = 1, float alpha
{
var res = NativeMethods.THSTensor_baddbmm(Handle, batch1.Handle, batch2.Handle, beta, alpha);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -3136,6 +3260,7 @@ public Tensor bmm(Tensor batch2)
{
var res = NativeMethods.THSTensor_bmm(Handle, batch2.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -3459,6 +3584,7 @@ public Tensor erfinv()
{
var res = NativeMethods.THSTensor_erfinv(Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -4207,7 +4333,8 @@ public Tensor mean()
}
///
- /// Returns the q-th quantiles of all elements in the input tensor, doing a linear interpolation when the q-th quantile lies between two data points.
+ /// Returns the q-th quantiles of all elements in the input tensor, doing a linear
+ /// interpolation when the q-th quantile lies between two data points.
///
/// 1D tensor of quantile values in the range [0, 1]
/// The dimension to reduce.
@@ -4426,6 +4553,7 @@ public Tensor dist(Tensor other, float p = 2.0f)
{
var res = NativeMethods.THSTensor_dist(Handle, other.Handle, p);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -4437,6 +4565,7 @@ public Tensor norm(float p = 2.0f)
{
var res = NativeMethods.THSTensor_norm(Handle, p);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -4447,6 +4576,7 @@ public Tensor norm(int dim, bool keepdim = false, float p = 2.0f)
{
var res = NativeMethods.THSTensor_norm_along_dimension(Handle, dim, keepdim, p);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -4491,6 +4621,7 @@ public Tensor prelu(Tensor target)
{
var res = NativeMethods.THSTensor_prelu(Handle, target.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res);
return new Tensor(res);
}
@@ -4536,6 +4667,7 @@ public Tensor renorm(float p, long dim, float maxnorm)
{
var res = NativeMethods.THSTensor_renorm(Handle, p, dim, maxnorm);
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -4958,6 +5090,7 @@ public Tensor prod(ScalarType? type = null)
{
var res = NativeMethods.THSTensor_prod(Handle, type.HasValue, (sbyte)type.GetValueOrDefault());
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -4968,6 +5101,7 @@ public Tensor prod(long dim, bool keepdim = false, ScalarType? type = null)
{
var res = NativeMethods.THSTensor_prod_along_dimensions(Handle, dim, keepdim, type.HasValue, (sbyte)type.GetValueOrDefault());
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -4978,6 +5112,7 @@ public Tensor sum(ScalarType? type = null)
{
var res = NativeMethods.THSTensor_sum(Handle, type.HasValue, (sbyte)type.GetValueOrDefault());
if (res == IntPtr.Zero) { CheckForErrors(); }
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -5852,6 +5987,13 @@ public Tensor scatter_(long dim, Tensor index, Tensor src)
///
public Tensor scatter_add(long dim, Tensor index, Tensor src)
{
+ if (AutocastMode.IsAutocastEnabled()) {
+ var sts = new[] { this.dtype, index.dtype, src.dtype };
+ if (sts.All(x => x == ScalarType.Float16))
+ (handle, index.handle, src.handle) = AutocastMode.AutoCast(handle, index.handle, src.handle, ScalarType.Float16);
+ if (sts.Any(x => x == ScalarType.Float32))
+ (handle, index.handle, src.handle) = AutocastMode.AutoCast(handle, index.handle, src.handle, ScalarType.Float32);
+ }
var res = NativeMethods.THSTensor_scatter_add(Handle, dim, index.Handle, src.Handle);
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
@@ -7483,5 +7625,16 @@ internal static Tensor InstantiateTensorWithLeakSafeTypeChange(IntPtr handle, Sc
}
return tensor;
}
+ public static void _amp_foreach_non_finite_check_and_unscale(Tensor found_inf, Tensor inv_scale)
+ {
+     // Both arguments must be 1-element tensors. (The previous version repeated the
+     // same found_inf guard four times and inverted the condition, throwing exactly
+     // when the tensor WAS a 1-element tensor.)
+     if (found_inf.numel() != 1)
+         throw new Exception("found_inf must be a 1-element tensor.");
+     if (inv_scale.numel() != 1)
+         throw new Exception("inv_scale must be a 1-element tensor.");
+     // NOTE(review): no check/unscale work is performed here — confirm whether this
+     // should delegate to torch._amp_foreach_non_finite_check_and_unscale_ (native).
+ }
}
}
\ No newline at end of file
diff --git a/src/TorchSharp/Tensor/torch.Amp.cs b/src/TorchSharp/Tensor/torch.Amp.cs
new file mode 100644
index 000000000..319afe65c
--- /dev/null
+++ b/src/TorchSharp/Tensor/torch.Amp.cs
@@ -0,0 +1,46 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using static TorchSharp.PInvoke.NativeMethods;
+
+namespace TorchSharp
+{
+ public static partial class torch
+ {
+        public static void _amp_foreach_non_finite_check_and_unscale_(IList<Tensor> tensors, Tensor found_inf, Tensor inv_scale)
+        {
+            using var ts = new PinnedArray<IntPtr>();
+ IntPtr tens = ts.CreateArray(tensors.Select(x => x.Handle).ToArray());
+ THSAmp_amp_foreach_non_finite_check_and_unscale_(tens, ts.Array.Length, found_inf.Handle, inv_scale.Handle);
+ }
+
+ public static torch.Tensor amp_update_scale_(Tensor self, Tensor growth_tracker, Tensor found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval)
+ {
+ var res = THSAmp_amp_update_scale_(self.Handle, growth_tracker.Handle, found_inf.Handle, scale_growth_factor, scale_backoff_factor, growth_interval);
+ if(res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return new Tensor(res);
+ }
+ public static torch.Tensor amp_update_scale_out(Tensor outt, Tensor self, Tensor growth_tracker, Tensor found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval)
+ {
+ var res = THSAmp_amp_update_scale_out(outt.Handle, self.Handle, growth_tracker.Handle, found_inf.Handle, scale_growth_factor, scale_backoff_factor, growth_interval);
+ if(res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return new Tensor(res);
+ }
+ public static torch.Tensor amp_update_scale_outf(Tensor self, Tensor growth_tracker, Tensor found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval, Tensor outt)
+ {
+ var res = THSAmp_amp_update_scale_outf(self.Handle, growth_tracker.Handle, found_inf.Handle, scale_growth_factor, scale_backoff_factor, growth_interval, outt.Handle);
+ if(res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return new Tensor(res);
+ }
+ public static (torch.Tensor, torch.Tensor) amp_update_scale(Tensor self, Tensor growth_tracker, Tensor found_inf, double scale_growth_factor, double scale_backoff_factor, long growth_interval)
+ {
+            var res = THSAmp_amp_update_scale(self.Handle, growth_tracker.Handle, found_inf.Handle, scale_growth_factor, scale_backoff_factor, growth_interval, out var res1); // was THSAMP_ — fixed to match the THSAmp_ prefix used by all other native calls; verify against the PInvoke declaration
+ if(res == IntPtr.Zero || res1 == IntPtr.Zero)
+ torch.CheckForErrors();
+ return (new Tensor(res), new Tensor(res1));
+ }
+ }
+}
diff --git a/src/TorchSharp/Tensor/torch.Autocast.cs b/src/TorchSharp/Tensor/torch.Autocast.cs
new file mode 100644
index 000000000..12e86d46d
--- /dev/null
+++ b/src/TorchSharp/Tensor/torch.Autocast.cs
@@ -0,0 +1,62 @@
+using System;
+using static TorchSharp.PInvoke.NativeMethods;
+
+namespace TorchSharp
+{
+ public static partial class torch
+ {
+ public static bool is_autocast_cache_enabled()
+ {
+ return THSAmp_is_autocast_cache_enabled();
+ }
+
+ public static bool is_autocast_available(DeviceType device)
+ {
+ //https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/init.cpp
+ return THSAmp_is_autocast_available((int)device);
+ }
+ public static bool is_autocast_enabled(DeviceType device)
+ {
+ return THSAmp_is_autocast_enabled((int)device);
+ //return THSAmp_is_autocast_cache_enabled();
+ }
+ public static ScalarType get_autocast_dtype(DeviceType device)
+ {
+ return (ScalarType)THSAmp_get_autocast_dtype((int)device);
+ }
+
+
+ public static int autocast_increment_nesting()
+ {
+ return THSAmp_autocast_increment_nesting();
+ }
+
+ public static int autocast_decrement_nesting()
+ {
+ return THSAmp_autocast_decrement_nesting();
+ }
+
+ public static void set_autocast_enabled(DeviceType device, bool enabled)
+ {
+ THSAmp_set_autocast_enabled((int)device,enabled);
+ }
+
+ public static void set_autocast_dtype(DeviceType device, ScalarType dtype)
+ {
+ THSAmp_set_autocast_dtype((int)device, (sbyte)dtype);
+ }
+ public static void set_autocast_cache_enabled(bool enabled)
+ {
+ THSAmp_set_autocast_cache_enabled(enabled);
+ }
+        public static void set_autocast_cache_enabled(DeviceType device, ScalarType dtype) // NOTE(review): misnamed — this sets the autocast dtype (duplicating set_autocast_dtype), not the cache flag; consider removing
+ {
+ THSAmp_set_autocast_dtype((int)device, (sbyte)dtype);
+ }
+
+ public static void clear_autocast_cache()
+ {
+ THSAmp_clear_autocast_cache();
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs b/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs
index ff6f4d6b1..6024eee82 100644
--- a/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs
+++ b/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs
@@ -143,7 +143,8 @@ public static Tensor cholesky_inverse(Tensor input, bool upper = false)
// https://pytorch.org/docs/stable/generated/torch.cholesky_solve
///
- /// Solves a linear system of equations with a positive semidefinite matrix to be inverted given its Cholesky factor matrix u.
+ /// Solves a linear
+ /// system of equations with a positive semidefinite matrix to be inverted given its Cholesky factor matrix u.
///
///
public static Tensor cholesky_solve(Tensor input, Tensor input2, bool upper = false)
@@ -317,6 +318,7 @@ public static (Tensor P, Tensor? L, Tensor? U) lu_unpack(Tensor LU_data, Tensor
///
///
public static Tensor mm(Tensor input, Tensor target) => input.mm(target);
+
// https://pytorch.org/docs/stable/generated/torch.mv
///
diff --git a/src/TorchSharp/Tensor/torch.OtherOperations.cs b/src/TorchSharp/Tensor/torch.OtherOperations.cs
index dfdb4a6ff..bf0c377b8 100644
--- a/src/TorchSharp/Tensor/torch.OtherOperations.cs
+++ b/src/TorchSharp/Tensor/torch.OtherOperations.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using TorchSharp.Amp;
using TorchSharp.PInvoke;
using static TorchSharp.PInvoke.NativeMethods;
@@ -167,6 +168,7 @@ public static Tensor cdist(
var res = THSTensor_cdist(x1.Handle, x2.Handle, p, (long)compute_mode);
if (res == IntPtr.Zero)
CheckForErrors();
+ res = AutocastMode.AutoCast(res, ScalarType.Float32);
return new Tensor(res);
}
@@ -229,6 +231,8 @@ public static Tensor cov(Tensor input, long correction = 1, Tensor? fweights = n
///
public static Tensor cross(Tensor input, Scalar other, long dim = 0L) => input.cross(other, dim);
+ public static Tensor cross(Tensor input, Tensor other, long dim = 0L) => input.cross(other, dim);
+
// https://pytorch.org/docs/stable/generated/torch.cummax
public static (Tensor values, Tensor indices) cummax(Tensor input, long dim) => input.cummax(dim);
diff --git a/src/TorchSharp/Tensor/torch.Utilities.cs b/src/TorchSharp/Tensor/torch.Utilities.cs
index 460a42e67..6e89134e8 100644
--- a/src/TorchSharp/Tensor/torch.Utilities.cs
+++ b/src/TorchSharp/Tensor/torch.Utilities.cs
@@ -2,6 +2,8 @@
#nullable enable
using System;
using System.Diagnostics.Contracts;
+using TorchSharp.Modules;
+using TorchSharp.PInvoke;
using static TorchSharp.PInvoke.NativeMethods;
#nullable enable
@@ -80,5 +82,23 @@ public static ScalarType promote_types(ScalarType type1, ScalarType type2)
[Obsolete("not implemented", true)]
public static void _assert(Func condition, string message) => throw new NotImplementedException();
+
+ public static void PrintModule(torch.nn.Module module)
+ {
+ if (module is Dropout2d drop2d) {
+ Console.WriteLine($"{module.GetName()}({drop2d.p}, {drop2d.inplace})");
+ return;
+ }
+
+ if (module is LayerNorm ln) {
+ string str= "[";
+ for (int i = 0; i < ln._normalized_shape.Length; i++)
+ str += ln._normalized_shape[i] + ",";
+ str = str.TrimEnd(',')+"]";
+ Console.WriteLine($"{module.GetName()}({ln._eps}, {str})");
+ return;
+ }
+ NativeMethods.THSNN_Print_Module(module.handle);
+ }
}
}
\ No newline at end of file
diff --git a/src/TorchSharp/Torch.cs b/src/TorchSharp/Torch.cs
index 728fa9ccd..15d4338fa 100644
--- a/src/TorchSharp/Torch.cs
+++ b/src/TorchSharp/Torch.cs
@@ -11,6 +11,7 @@
using System.Text.RegularExpressions;
using TorchSharp.Modules;
using TorchSharp.PInvoke;
+using TorchSharp.Utils;
using static TorchSharp.PInvoke.NativeMethods;
#nullable enable
@@ -55,7 +56,8 @@ public static partial class torch
public static string __version__ => libtorchPackageVersion;
- internal static bool TryLoadNativeLibraryFromFile(string path, StringBuilder trace) {
+ internal static bool TryLoadNativeLibraryFromFile(string path, StringBuilder trace)
+ {
bool ok;
try {
trace.AppendLine($" Trying to load native component {path}");
@@ -207,8 +209,7 @@ private static void LoadNativeBackend(bool useCudaBackend, out StringBuilder? tr
throw new NotSupportedException(message);
}
}
- }
- else {
+ } else {
trace.AppendLine(" Giving up, TorchSharp.dll does not appear to have been loaded from package directories");
}
if (!ok) {
@@ -268,8 +269,7 @@ private static bool CopyNativeComponentsIntoSingleDirectory(string packagesDir,
public static bool TryInitializeDeviceType(DeviceType deviceType)
{
- if (deviceType == DeviceType.MPS && !isAppleSilicon)
- {
+ if (deviceType == DeviceType.MPS && !isAppleSilicon) {
return false;
}
@@ -283,8 +283,7 @@ public static bool TryInitializeDeviceType(DeviceType deviceType)
public static void InitializeDeviceType(DeviceType deviceType)
{
- if (deviceType == DeviceType.MPS && !isAppleSilicon)
- {
+ if (deviceType == DeviceType.MPS && !isAppleSilicon) {
throw new InvalidOperationException($"Torch device type 'MPS' is not available on this platform.");
}
@@ -511,7 +510,6 @@ public static Linear fuse_linear_bn_eval(Linear linear, BatchNorm bn)
public static partial class cuda
{
-
/// This must be a separate method to the failure to bind DllImport THSTorchCuda_is_available
/// is not raised as early as a DllImportException
[System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.NoInlining)]
@@ -581,6 +579,69 @@ public static void synchronize(Device? device = null)
TryInitializeDeviceType(device?.type ?? DeviceType.CUDA);
THSTorchCuda_synchronize(device?.index ?? -1);
}
+
+ public static bool is_bf16_supported()
+ {
+ //TODO IMPLEMENT: torch.cuda.current_device() https://github.com/pytorch/pytorch/blob/a4cc6b85dc14d5895499f89f39181c00196d336e/torch/cuda/__init__.py#L153
+ if (int.TryParse(cudaVersion.Split('.')[0], out int res)){
+
+ //TODO: Implement get device properties
+ //WARNING: Need Major compute capability version https://github.com/pytorch/pytorch/blob/a4cc6b85dc14d5895499f89f39181c00196d336e/torch/cuda/__init__.py#L161
+ var compute = torch.cuda.get_compute_capability();
+ if (res >= 11 && compute.major >= 8)
+ return true;
+ }
+
+ return check_bf16_tensor_supported(torch.CUDA);
+ }
+
+ private static bool check_bf16_tensor_supported(torch.Device dev)
+ {
+ try {
+ var va = torch.tensor(new float[] { 1.0f }, dtype: ScalarType.BFloat16, device: dev);
+ return true;
+ } catch {
+ return false;
+ }
+ }
+
+ public static (int major, int minor) get_compute_capability()
+ {
+ return (THSCuda_get_major_compute_capability(), THSCuda_get_minor_compute_capability());
+ }
+
+ public static (int res, int id, ulong free, ulong total) get_free_total_memory(int device)
+ {
+ int id = 0;
+ ulong f=0;
+ ulong t=0;
+ int res = THSCuda_get_free_total(device, ref id, ref f, ref t);
+ return (res, id, f, t);
+ }
+
+ public static int get_device_count(ref int count)
+ {
+ return THSCuda_get_device_count(ref count);
+ }
+
+ public static ulong get_total_memory(int device)
+ {
+ return THSCuda_get_total_memory(device);
+ }
+ public static ulong get_global_total_memory(int device)
+ {
+ return THSCuda_get_global_total_memory(device);
+ }
+ /*public static cudaDeviceProp get_device_prop(int device)
+ {
+#if CUDA_TOOLKIT_FOUND
+ cudaDeviceProp cdp = new cudaDeviceProp();
+ throw new NotImplementedException("Implement the cudaDeviceProp THSCuda");
+ //return cdp;
+#else
+ return null;
+#endif
+ }*/
}
///
diff --git a/src/TorchSharp/TorchSharp.csproj b/src/TorchSharp/TorchSharp.csproj
index 5a102f34e..c959e0619 100644
--- a/src/TorchSharp/TorchSharp.csproj
+++ b/src/TorchSharp/TorchSharp.csproj
@@ -3,14 +3,14 @@
- net6.0;netstandard2.0
- 9.0
- TorchSharp
- true
- false
- false
- false
- $(DefineConstants);LIBTORCH_$(LibTorchPackageVersion.Replace('.', '_'));CUDA_$(CudaVersionDot.Replace('.', '_'))
+ netstandard2.0;net6.0
+ 9.0
+ TorchSharp
+ true
+ false
+ false
+ false
+ $(DefineConstants);LIBTORCH_$(LibTorchPackageVersion.Replace('.', '_'));CUDA_$(CudaVersionDot.Replace('.', '_'))
@@ -19,6 +19,11 @@
+
+
+
+
+
@@ -49,29 +54,40 @@
-
- $(PackDependsOn);
- RealPack
-
- True
- ..\..\build\TorchSharp.snk
+
+ $(PackDependsOn);
+ RealPack
+
+ True
+ ..\..\build\TorchSharp.snk
+
+
+
+
+ 4
+
+
+
+
+ 4
-
+
-
+
+
-
+
diff --git a/src/TorchSharp/Utils/BFloat16.cs b/src/TorchSharp/Utils/BFloat16.cs
new file mode 100644
index 000000000..fef947389
--- /dev/null
+++ b/src/TorchSharp/Utils/BFloat16.cs
@@ -0,0 +1,49 @@
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace System
+{
+ [StructLayout(LayoutKind.Sequential,Pack=2)]
+ public struct BFloat16
+ {
+ [MarshalAs(UnmanagedType.I2)]
+ private short x;
+ public struct from_bits_t{};
+ }
+
+ /*
+ *
+struct alignas(2) BFloat16 {
+ uint16_t x;
+
+ // HIP wants __host__ __device__ tag, CUDA does not
+#if defined(USE_ROCM)
+ C10_HOST_DEVICE BFloat16() = default;
+#else
+ BFloat16() = default;
+#endif
+
+ struct from_bits_t {};
+ static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
+ return from_bits_t();
+ }
+
+ constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
+ : x(bits) {}
+ inline C10_HOST_DEVICE BFloat16(float value);
+ inline C10_HOST_DEVICE operator float() const;
+
+#if defined(__CUDACC__) && !defined(USE_ROCM)
+ inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
+ explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
+#endif
+
+#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
+ inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
+ explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
+#endif
+};
+ */
+}
diff --git a/src/TorchSharp/Utils/FastTensorAccessor.cs b/src/TorchSharp/Utils/FastTensorAccessor.cs
new file mode 100644
index 000000000..142b95d6c
--- /dev/null
+++ b/src/TorchSharp/Utils/FastTensorAccessor.cs
@@ -0,0 +1,712 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Runtime.InteropServices;
+using static TorchSharp.PInvoke.NativeMethods;
+
+namespace TorchSharp.Utils
+{
+ ///
+ /// FastTensorAccessor is used to present the contents of a tensor or tensor view to the .NET world as an ordered collection
+ /// of values that integrates well with things like LINQ and foreach loops in the .NET world.
+ ///
+ /// The type of the tensor elements.
+ public sealed class FastTensorAccessor : IDisposable, IEnumerable where T : unmanaged
+ {
+ internal FastTensorAccessor(torch.Tensor tensor)
+ {
+ if (tensor.device_type != DeviceType.CPU) {
+ throw new InvalidOperationException("Reading data from non-CPU memory is not supported. Move or copy the tensor to the cpu before reading.");
+ }
+
+ var strides = tensor.stride();
+ for (var i = 0; i < strides.Length; i++) {
+ if (strides[i] < 0)
+ throw new NotImplementedException($"Negative tensor strides are not currently supported. tensor.strides({i}) == {strides[i]}");
+ }
+
+ // Get the data from native code.
+
+ unsafe {
+ var res = THSTensor_data(tensor.Handle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ // NOTE: there is no safety here.
+ _tensor_data_ptr = res;
+ }
+
+ _tensor = tensor; // Keep the tensor alive now that everything is alright.
+ }
+
+ ///
+ /// Performance note: this cached element count is used by CopyTo/CopyFrom so that tensor.numel() is not called on every invocation, since that native call is CPU-intensive.
+ /// Caching the count avoids that overhead; the Count property behaves like a method call.
+ /// For example, a 640*640*3 tensor has 1,228,800 elements, so the property would otherwise be invoked over a million times.
+ /// When we only want to copy, there is no need to call numel() that many times.
+ /// For some reason the numel() method consumes a lot of CPU.
+ ///
+ internal long TempCount = -1;
+ public long Count => _tensor?.numel() ?? 0;
+
+ public bool IsReadOnly => false;
+
+ public T[] ToArray()
+ {
+ if (_tensor.ndim < 2)
+ return (T[])ToNDArray();
+
+ var shps = _tensor.shape;
+ TempCount = 1;
+ for (int i = 0; i < shps.Length; i++)
+ TempCount *= shps[i]; //Theoretically, numel is simply the product of the shape dimensions.
+
+ if (_tensor.is_contiguous()) { //Fast path: contiguous memory can be copied directly as one span.
+ unsafe {
+ return new Span(_tensor_data_ptr.ToPointer(), Convert.ToInt32(TempCount)).ToArray();
+ }
+ }
+ var result = new T[TempCount];
+ CopyTo(result);
+ return result;
+ }
+
+ ///
+ /// Extract tensor data as a multi-dimensional .NET array, with the same number of dimensions as the tensor.
+ ///
+ /// An array object, which should be cast to the concrete array type.
+ public Array ToNDArray()
+ {
+ var shape = _tensor.shape;
+ var strides = _tensor.stride();
+ switch (_tensor.ndim) {
+ default:
+ return ToNDArray(shape, strides);
+ case 0:
+ unsafe {
+ var result = new T[1];
+ T* ptr = (T*)_tensor_data_ptr;
+ result[0] = ptr[0];
+ return result;
+ }
+ case 1:
+ unsafe {
+ var result = new T[shape[0]];
+ T* ptr = (T*)_tensor_data_ptr;
+ for (long i0 = 0, off0 = 0; i0 < shape[0]; i0++, off0 += strides[0]) {
+ result[i0] = ptr[off0];
+ }
+ return result;
+ }
+ case 2:
+ unsafe {
+ var result = new T[shape[0], shape[1]];
+ T* ptr = (T*)_tensor_data_ptr;
+ for (long i0 = 0, off0 = 0; i0 < shape[0]; i0++, off0 += strides[0]) {
+ for (long i1 = 0, off1 = off0; i1 < shape[1]; i1++, off1 += strides[1]) {
+ result[i0, i1] = ptr[off1];
+ }
+ }
+ return result;
+ }
+ case 3:
+ unsafe {
+ var result = new T[shape[0], shape[1], shape[2]];
+ T* ptr = (T*)_tensor_data_ptr;
+ for (long i0 = 0, off0 = 0; i0 < shape[0]; i0++, off0 += strides[0]) {
+ for (long i1 = 0, off1 = off0; i1 < shape[1]; i1++, off1 += strides[1]) {
+ for (long i2 = 0, off2 = off1; i2 < shape[2]; i2++, off2 += strides[2]) {
+ result[i0, i1, i2] = ptr[off2];
+ }
+ }
+ }
+ return result;
+ }
+ case 4:
+ unsafe {
+ var result = new T[shape[0], shape[1], shape[2], shape[3]];
+ T* ptr = (T*)_tensor_data_ptr;
+ for (long i0 = 0, off0 = 0; i0 < shape[0]; i0++, off0 += strides[0]) {
+ for (long i1 = 0, off1 = off0; i1 < shape[1]; i1++, off1 += strides[1]) {
+ for (long i2 = 0, off2 = off1; i2 < shape[2]; i2++, off2 += strides[2]) {
+ for (long i3 = 0, off3 = off2; i3 < shape[3]; i3++, off3 += strides[3]) {
+ result[i0, i1, i2, i3] = ptr[off3];
+ }
+ }
+ }
+ }
+ return result;
+ }
+ case 5:
+ unsafe {
+ var result = new T[shape[0], shape[1], shape[2], shape[3], shape[4]];
+ T* ptr = (T*)_tensor_data_ptr;
+ for (long i0 = 0, off0 = 0; i0 < shape[0]; i0++, off0 += strides[0]) {
+ for (long i1 = 0, off1 = off0; i1 < shape[1]; i1++, off1 += strides[1]) {
+ for (long i2 = 0, off2 = off1; i2 < shape[2]; i2++, off2 += strides[2]) {
+ for (long i3 = 0, off3 = off2; i3 < shape[3]; i3++, off3 += strides[3]) {
+ for (long i4 = 0, off4 = off3; i4 < shape[4]; i4++, off4 += strides[4]) {
+ result[i0, i1, i2, i3, i4] = ptr[off4];
+ }
+ }
+ }
+ }
+ }
+ return result;
+ }
+ case 6:
+ unsafe {
+ var result = new T[shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]];
+ T* ptr = (T*)_tensor_data_ptr;
+ for (long i0 = 0, off0 = 0; i0 < shape[0]; i0++, off0 += strides[0]) {
+ for (long i1 = 0, off1 = off0; i1 < shape[1]; i1++, off1 += strides[1]) {
+ for (long i2 = 0, off2 = off1; i2 < shape[2]; i2++, off2 += strides[2]) {
+ for (long i3 = 0, off3 = off2; i3 < shape[3]; i3++, off3 += strides[3]) {
+ for (long i4 = 0, off4 = off3; i4 < shape[4]; i4++, off4 += strides[4]) {
+ for (long i5 = 0, off5 = off4; i5 < shape[5]; i5++, off5 += strides[5]) {
+ result[i0, i1, i2, i3, i4, i5] = ptr[off5];
+ }
+ }
+ }
+ }
+ }
+ }
+ return result;
+ }
+ }
+ }
+
+ private Array ToNDArray(long[] shape, long[] strides)
+ {
+ Array array = Array.CreateInstance(typeof(T), shape);
+ long[] indexes = new long[_tensor.ndim];
+ long[] off = new long[_tensor.ndim];
+
+ while (true) {
+ unsafe {
+ T* ptr = (T*)_tensor_data_ptr;
+ array.SetValue(ptr[off[array.Rank - 1]], indexes);
+ }
+
+ for (int i = array.Rank - 1; i >= 0; i--) {
+ if (indexes[i] < shape[i] - 1) {
+ indexes[i]++;
+ off[i] += strides[i];
+ for (int j = i; j < array.Rank - 1; j++)
+ off[j + 1] = off[j];
+ break;
+ } else {
+ if (i == 0) {
+ return array;
+ }
+ indexes[i] = 0;
+ }
+ }
+ }
+ }
+
+ ///
+ /// Access elements of the underlying tensor / tensor view.
+ ///
+ /// A linear index into the data.
+ ///
+ public T this[params long[] indices] {
+ get {
+ long index = 0;
+ if (indices.Length == 1) {
+ index = indices[0];
+ validate(index);
+ unsafe {
+ T* ptr = (T*)_tensor_data_ptr;
+ return ptr[TranslateIndex(index, _tensor)];
+ }
+ } else {
+ unsafe {
+ T* ptr = (T*)_tensor_data_ptr;
+ return ptr[TranslateIndex(indices, _tensor)];
+ }
+ }
+ }
+ set {
+ long index = 0;
+ if (indices.Length == 1) {
+ index = indices[0];
+ validate(index);
+ unsafe {
+ T* ptr = (T*)_tensor_data_ptr;
+ ptr[TranslateIndex(indices, _tensor)] = value;
+ }
+ } else {
+ unsafe {
+ T* ptr = (T*)_tensor_data_ptr;
+ ptr[TranslateIndex(indices, _tensor)] = value;
+ }
+ }
+ }
+ }
+
+ private void validate(long index)
+ {
+ if (index >= Count) throw new IndexOutOfRangeException();
+ }
+
+ public void CopyTo(T[] array, int arrayIndex = 0, long tensorIndex = 0)
+ {
+ int idx = arrayIndex;
+ /*if (_tensor.is_contiguous()) {
+ if (typeof(T) == typeof(float)) {
+ float[] ff = new float[TempCount];
+ Marshal.Copy(_tensor_data_ptr, ff, 0,ff.Length);
+ }
+ }*/
+ //For a contiguous tensor the indices are simply the range [tensorIndex, numel), so there is no need to materialize an index sequence (the enumerable is lazy anyway).
+ if (_tensor.is_contiguous()) {
+ for (long i = tensorIndex; i < TempCount; i++)
+ unsafe { array[i] = ((T*)_tensor_data_ptr)[i]; }
+ return;
+ }
+ foreach (int offset in GetSubsequentIndices(tensorIndex)) {
+ if (idx >= array.Length) break;
+ unsafe { array[idx] = ((T*)_tensor_data_ptr)[offset]; }
+ idx += 1;
+ }
+ }
+
+ public void CopyTo(Span array, int arrayIndex = 0, long tensorIndex = 0)
+ {
+ int idx = arrayIndex;
+ foreach (int offset in GetSubsequentIndices(tensorIndex)) {
+ if (idx >= array.Length) break;
+ unsafe { array[idx] = ((T*)_tensor_data_ptr)[offset]; }
+ idx += 1;
+ }
+ }
+
+ public void CopyFrom(T[] array, int arrayIndex = 0, long tensorIndex = 0)
+ {
+ int idx = arrayIndex;
+ foreach (int offset in GetSubsequentIndices(tensorIndex)) {
+ if (idx >= array.Length) break;
+ unsafe { ((T*)_tensor_data_ptr)[offset] = array[idx]; }
+ idx += 1;
+ }
+ }
+
+ public void CopyFrom(ReadOnlySpan array, int arrayIndex = 0, long tensorIndex = 0)
+ {
+ int idx = arrayIndex;
+ foreach (int offset in GetSubsequentIndices(tensorIndex)) {
+ if (idx >= array.Length) break;
+ unsafe { ((T*)_tensor_data_ptr)[offset] = array[idx]; }
+ idx += 1;
+ }
+ }
+
+ ///
+ /// Translates a linear index within the span represented by the accessor to a linear index
+ /// used by the underlying tensor. The two should only be different if the tensor is a view
+ /// rather than an allocated tensor.
+ ///
+ private static long TranslateIndex(long idx, torch.Tensor tensor)
+ {
+ if (idx >= tensor.numel() || idx < 0)
+ throw new ArgumentOutOfRangeException($"{idx} in a collection of ${tensor.numel()} elements.");
+
+ if (tensor.is_contiguous() || idx == 0) return idx;
+
+ long result = 0;
+ var shape = tensor.shape;
+ var strides = tensor.stride();
+
+ for (var i = shape.Length - 1; i >= 0; i--) {
+ idx = Math.DivRem(idx, shape[i], out long s);
+ result += s * strides[i];
+ }
+
+ return result;
+ }
+ ///
+ /// WARNING: For testing purposes only; do not use in production.
+ ///
+ private long TranslateIndexNonStatic(long idx, torch.Tensor tensor)
+ {
+ if (idx >= TempCount || idx < 0)
+ throw new ArgumentOutOfRangeException($"{idx} in a collection of ${tensor.numel()} elements.");
+
+ if (tensor.is_contiguous() || idx == 0) return idx;
+
+ long result = 0;
+ var shape = tensor.shape;
+ var strides = tensor.stride();
+
+ for (var i = shape.Length - 1; i >= 0; i--) {
+ idx = Math.DivRem(idx, shape[i], out long s);
+ result += s * strides[i];
+ }
+
+ return result;
+ }
+ private static long TranslateIndex(long[] idx, torch.Tensor tensor)
+ {
+ long result = 0;
+ var shape = tensor.shape;
+ var strides = tensor.stride();
+
+ for (var i = shape.Length - 1; i >= 0; i--) {
+ if (idx[i] >= shape[i] || idx[i] < 0)
+ throw new IndexOutOfRangeException($"{idx[i]} >= {shape[i]} in dimension {i}.");
+ result += idx[i] * strides[i];
+ }
+
+ return result;
+ }
+
+ internal static T ReadItemAt(torch.Tensor tensor, long index)
+ {
+ if (tensor.device_type != DeviceType.CPU) {
+ throw new InvalidOperationException("Reading data from non-CPU memory is not supported. Move or copy the tensor to the cpu before reading.");
+ }
+
+ tensor.ValidateType(typeof(T));
+
+ var strides = tensor.stride();
+ for (var i = 0; i < strides.Length; i++) {
+ if (strides[i] < 0)
+ throw new NotImplementedException($"Negative tensor strides are not currently supported. tensor.strides({i}) == {strides[i]}");
+ }
+
+ unsafe {
+ var res = THSTensor_data(tensor.Handle);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ // NOTE: there is no safety here.
+ T* ptr = (T*)res;
+ return ptr[TranslateIndex(index, tensor)];
+ }
+ }
+
+ ///
+ /// Compare two tensors element-wise.
+ ///
+ /// A tensor
+ /// Another tensor
+ ///
+ public static bool operator ==(FastTensorAccessor left, FastTensorAccessor right)
+ {
+ if (left.Count != right.Count) return false;
+
+ var lEnum = left.GetEnumerator();
+ var rEnum = right.GetEnumerator();
+
+ while (lEnum.MoveNext() && rEnum.MoveNext()) {
+ if (!lEnum.Current.Equals(rEnum.Current))
+ return false;
+ }
+ return true;
+ }
+
+ ///
+ /// Compare two tensors element-wise.
+ ///
+ /// A tensor
+ /// Another tensor
+ ///
+ public static bool operator !=(FastTensorAccessor left, FastTensorAccessor right)
+ {
+ return !(left == right);
+ }
+
+
+ private IEnumerable GetSubsequentIndices(long startingIndex)
+ {
+ //TempCount = Count;
+
+ if (startingIndex < 0 || startingIndex >= TempCount)
+ throw new ArgumentOutOfRangeException(nameof(startingIndex));
+
+ if (TempCount <= 1) {
+ if (TempCount == 0) {
+ return Enumerable.Empty();
+ }
+
+ return new List() { 0 };
+ //return (new long[] { 0 }).AsEnumerable();
+ }
+
+ if (_tensor.is_contiguous()) {
+ return ContiguousIndices(startingIndex);
+ }
+
+ var stride = _tensor.stride();
+ Debug.Assert(stride.Length > 0);
+
+ if (stride.Length == 1) {
+ return SimpleIndices(startingIndex, stride[0]);
+ }
+
+ return MultiDimensionIndices(startingIndex);
+ }
+ private IEnumerable MultiDimensionIndices(long startingIndex)
+ {
+ long[] shape = _tensor.shape;
+ long[] stride = _tensor.stride();
+ long[] inds = new long[stride.Length];
+
+ long index = startingIndex;
+ //long offset = TranslateIndex(startingIndex, _tensor);
+ long offset = TranslateIndexNonStatic(startingIndex, _tensor); //WARNING: for testing purposes only; do not use in production
+
+ while (true) {
+
+ index += 1;
+
+ yield return offset;
+
+ if (index >= TempCount) break;
+
+ for (int i = inds.Length - 1; ; i--) {
+ Debug.Assert(i >= 0);
+ offset += stride[i];
+ if (++inds[i] < shape[i])
+ break;
+
+ // Overflow of current dimension so rewind accordingly.
+ // Can't overflow the final (left-most) dimension.
+ Debug.Assert(i > 0);
+ // Note: for perf, this multiplication could be done once up front and cached in an array.
+ offset -= inds[i] * stride[i];
+ inds[i] = 0;
+ }
+ }
+ }
+
+ private IEnumerable SimpleIndices(long startingIndex, long stride)
+ {
+ long index = startingIndex;
+ //long offset = TranslateIndex(startingIndex, _tensor);
+ long offset = TranslateIndexNonStatic(startingIndex, _tensor); //WARNING: for testing purposes only; do not use in production
+
+ while (index < TempCount) {
+ yield return offset;
+ offset += stride;
+ index += 1;
+ }
+ }
+
+ private IEnumerable ContiguousIndices(long startingIndex)
+ {
+ // If there was an overload for Enumerable.Range that
+ // produced long integers, we wouldn't need this implementation.
+
+ long index = startingIndex;
+ while (index < TempCount) {
+ yield return index;
+ index += 1;
+ }
+ }
+
+
+ ///
+ /// Compare two tensors element-wise.
+ ///
+ /// Another tensor
+ ///
+ public override bool Equals(object obj)
+ {
+ var left = this;
+ var right = obj as FastTensorAccessor;
+ if (right == null) return false;
+
+ if (left._tensor_data_ptr == right._tensor_data_ptr) return true;
+ if (left.Count != right.Count) return false;
+ for (long i = 0; i < left.Count; i++) {
+ if (!left[i].Equals(right[i])) return false;
+ }
+ return true;
+ }
+
+ public override int GetHashCode()
+ {
+ return base.GetHashCode();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return GetEnumerator();
+ }
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ private void Dispose(bool disposing)
+ {
+ _tensor_data_ptr = IntPtr.Zero;
+ // Clear the tensor that we've been keeping alive.
+ _tensor = null;
+ }
+
+ private torch.Tensor _tensor; // Keeping it alive.
+ private IntPtr _tensor_data_ptr;
+
+#if true
+ public IEnumerator GetEnumerator()
+ {
+ if (TempCount <= 1) {
+ if (TempCount == 0)
+ return Enumerable.Empty().GetEnumerator();
+ return new T[1] { this[0] }.AsEnumerable().GetEnumerator();
+ }
+ /*if (Count <= 1) {
+ if (Count == 0)
+ return Enumerable.Empty().GetEnumerator();
+ return new T[1] { this[0] }.AsEnumerable().GetEnumerator();
+ }*/
+
+ if (_tensor.is_contiguous()) {
+ return new SimpleAtorImpl(this, 1);
+ }
+
+ var stride = _tensor.stride();
+ Debug.Assert(stride.Length > 0);
+
+ if (stride.Length == 1) {
+ return new SimpleAtorImpl(this, stride[0]);
+ }
+
+ return new GeneralAtorImpl(this, stride);
+ }
+
+ private class SimpleAtorImpl : IEnumerator
+ {
+ private FastTensorAccessor _span;
+ private readonly long _count;
+ private readonly long _stride;
+
+ // State.
+ private long _index;
+ private long _offset;
+ private T _current;
+
+ public SimpleAtorImpl(FastTensorAccessor span, long stride)
+ {
+ _span = span;
+ _count = span.TempCount;
+ Debug.Assert(_count > 0);
+ _stride = stride;
+ Reset();
+ }
+
+ public T Current => _current;
+ object IEnumerator.Current => Current;
+
+ public void Dispose()
+ {
+ _span = null;
+ Reset();
+ }
+
+ public bool MoveNext()
+ {
+ if (_index < 0) {
+ _index = 0;
+ _offset = 0;
+ } else if (++_index >= _count) {
+ Reset();
+ return false;
+ } else {
+ _offset += _stride;
+ }
+
+ unsafe { _current = ((T*)_span._tensor_data_ptr)[_offset]; }
+ return true;
+ }
+
+ public void Reset()
+ {
+ _index = -1;
+ _offset = -1;
+ _current = default;
+ }
+ }
+
+ private class GeneralAtorImpl : IEnumerator
+ {
+ private FastTensorAccessor