Skip to content

Commit 35fcb62

Browse files
author
Kapil Borle
committed
Add method to check executing PSVersion from Helper
1 parent cb7068e commit 35fcb62

File tree

3 files changed

+120
-25
lines changed

3 files changed

+120
-25
lines changed

Engine/Commands/InvokeScriptAnalyzerCommand.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
using System.Collections.Concurrent;
2727
using System.Threading;
2828
using System.Management.Automation.Runspaces;
29+
using System.Collections;
2930

3031
namespace Microsoft.Windows.PowerShell.ScriptAnalyzer.Commands
3132
{
@@ -220,6 +221,12 @@ protected override void BeginProcessing()
220221
this);
221222
Helper.Instance.Initialize();
222223

224+
var psVersionTable = this.SessionState.PSVariable.GetValue("PSVersionTable") as Hashtable;
225+
if (psVersionTable != null)
226+
{
227+
Helper.Instance.SetPSVersionTable(psVersionTable);
228+
}
229+
223230
string[] rulePaths = Helper.ProcessCustomRulePaths(customRulePath,
224231
this.SessionState, recurseCustomRulePath);
225232
if (IsFileParameterSet())

Engine/Helper.cs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
using System.Globalization;
2121
using Microsoft.Windows.PowerShell.ScriptAnalyzer.Generic;
2222
using System.Management.Automation.Runspaces;
23+
using System.Collections;
2324

2425
namespace Microsoft.Windows.PowerShell.ScriptAnalyzer
2526
{
@@ -36,6 +37,7 @@ public class Helper
3637
private Object getCommandLock = new object();
3738
private readonly static Version minSupportedPSVersion = new Version(3, 0);
3839
private Dictionary<string, Dictionary<string, object>> ruleArguments;
40+
private PSVersionTable psVersionTable;
3941

4042
#endregion
4143

@@ -1860,6 +1862,26 @@ public static bool IsModuleManifest(string filepath, Version powershellVersion =
18601862
// check if the keys given in module manifest are a proper subset of Keys
18611863
return map.Keys.All(x => allKeys.Concat(deprecatedKeys).Contains(x, StringComparer.OrdinalIgnoreCase));
18621864
}
1865+
1866+
public void SetPSVersionTable(Hashtable psVersionTable)
1867+
{
1868+
if (psVersionTable == null)
1869+
{
1870+
throw new ArgumentNullException("psVersionTable");
1871+
}
1872+
1873+
this.psVersionTable = new PSVersionTable(psVersionTable);
1874+
}
1875+
1876+
#if CORECLR
1877+
public SemanticVersion GetPSVersion()
1878+
#else
1879+
public Version GetPSVersion()
1880+
#endif
1881+
{
1882+
return psVersionTable == null ? null : psVersionTable.PSVersion;
1883+
}
1884+
18631885
#endregion Methods
18641886
}
18651887

@@ -3802,6 +3824,49 @@ private int GetIndex(T vertex)
38023824
int idx;
38033825
return vertexIndexMap.TryGetValue(vertex, out idx) ? idx : -1;
38043826
}
3827+
}
3828+
3829+
internal class PSVersionTable
3830+
{
3831+
private readonly string psVersionKey = "PSVersion";
3832+
private readonly string psEditionKey = "PSEdition";
3833+
#if CORECLR
3834+
public SemanticVersion PSVersion { get; private set; }
3835+
#else
3836+
public Version PSVersion { get; private set; }
3837+
#endif
3838+
public string PSEdition { get; private set; }
38053839

3840+
public PSVersionTable(Hashtable psVersionTable)
3841+
{
3842+
if (psVersionTable == null)
3843+
{
3844+
throw new ArgumentNullException("psVersionTable");
3845+
}
3846+
3847+
if (!psVersionTable.ContainsKey(psVersionKey))
3848+
{
3849+
throw new ArgumentException("Input PSVersionTable does not contain PSVersion key"); // TODO localize
3850+
}
3851+
3852+
#if CORECLR
3853+
PSVersion = psVersionTable[psVersionKey] as SemanticVersion;
3854+
#else
3855+
PSVersion = psVersionTable[psVersionKey] as Version;
3856+
#endif
3857+
if (PSVersion == null)
3858+
{
3859+
throw new ArgumentException("Input PSVersionTable has invalid PSVersion value type"); // TODO localize
3860+
}
3861+
3862+
if (psVersionTable.ContainsKey(psEditionKey))
3863+
{
3864+
PSEdition = psVersionTable[psEditionKey] as string;
3865+
if (PSEdition == null)
3866+
{
3867+
throw new ArgumentException("Input PSVersionTable has invalid PSEdition value type"); // TODO localize
3868+
}
3869+
}
3870+
}
38063871
}
38073872
}

Rules/UsePSCredentialType.cs

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,40 +44,51 @@ public IEnumerable<DiagnosticRecord> AnalyzeScript(Ast ast, string fileName)
4444
if (ast == null) throw new ArgumentNullException(Strings.NullAstErrorMessage);
4545

4646
var sbAst = ast as ScriptBlockAst;
47+
48+
var requiresTransformationAttribute = false;
49+
var psVersion = Helper.Instance.GetPSVersion();
50+
if (psVersion != null && psVersion.Major < 5)
51+
{
52+
requiresTransformationAttribute = true;
53+
}
54+
4755
if (sbAst != null
4856
&& sbAst.ScriptRequirements != null
4957
&& sbAst.ScriptRequirements.RequiredPSVersion != null
5058
&& sbAst.ScriptRequirements.RequiredPSVersion.Major == 5)
5159
{
52-
yield break;
60+
if (requiresTransformationAttribute)
61+
{
62+
yield break;
63+
}
5364
}
5465

5566
IEnumerable<Ast> funcDefAsts = ast.FindAll(testAst => testAst is FunctionDefinitionAst, true);
5667
IEnumerable<Ast> scriptBlockAsts = ast.FindAll(testAst => testAst is ScriptBlockAst, true);
5768

58-
string funcName;
59-
IEnumerable<DiagnosticRecord> diagnosticRecords = Enumerable.Empty<DiagnosticRecord>();
69+
List<DiagnosticRecord> diagnosticRecords = new List<DiagnosticRecord>();
6070

6171
foreach (FunctionDefinitionAst funcDefAst in funcDefAsts)
6272
{
63-
funcName = funcDefAst.Name;
64-
73+
IEnumerable<ParameterAst> parameterAsts = null;
6574
if (funcDefAst.Parameters != null)
6675
{
67-
diagnosticRecords.Concat(GetViolations(
68-
funcDefAst.Parameters,
69-
funcDefAst,
70-
string.Format(CultureInfo.CurrentCulture, Strings.UsePSCredentialTypeError, funcName),
71-
fileName));
76+
parameterAsts = funcDefAst.Parameters;
7277
}
7378

74-
if (funcDefAst.Body.ParamBlock != null)
79+
if (funcDefAst.Body.ParamBlock != null
80+
&& funcDefAst.Body.ParamBlock.Parameters != null)
7581
{
76-
diagnosticRecords.Concat(GetViolations(
77-
funcDefAst.Body.ParamBlock.Parameters,
78-
funcDefAst,
79-
string.Format(CultureInfo.CurrentCulture, Strings.UsePSCredentialTypeError, funcName),
80-
fileName));
82+
parameterAsts = funcDefAst.Body.ParamBlock.Parameters;
83+
}
84+
85+
if (parameterAsts != null)
86+
{
87+
diagnosticRecords.AddRange(GetViolations(
88+
parameterAsts,
89+
string.Format(CultureInfo.CurrentCulture, Strings.UsePSCredentialTypeError, funcDefAst.Name),
90+
fileName,
91+
requiresTransformationAttribute));
8192
}
8293
}
8394

@@ -91,11 +102,11 @@ public IEnumerable<DiagnosticRecord> AnalyzeScript(Ast ast, string fileName)
91102

92103
if (scriptBlockAst.ParamBlock != null && scriptBlockAst.ParamBlock.Parameters != null)
93104
{
94-
diagnosticRecords.Concat(GetViolations(
105+
diagnosticRecords.AddRange(GetViolations(
95106
scriptBlockAst.ParamBlock.Parameters,
96-
scriptBlockAst,
97107
string.Format(CultureInfo.CurrentCulture, Strings.UsePSCredentialTypeErrorSB),
98-
fileName));
108+
fileName,
109+
requiresTransformationAttribute));
99110
}
100111
}
101112

@@ -107,35 +118,47 @@ public IEnumerable<DiagnosticRecord> AnalyzeScript(Ast ast, string fileName)
107118

108119
private IEnumerable<DiagnosticRecord> GetViolations(
109120
IEnumerable<ParameterAst> parameterAsts,
110-
Ast parentAst,
111121
string errorMessage,
112-
string fileName)
122+
string fileName,
123+
bool requiresTransformationAttribute)
113124
{
114125
foreach (ParameterAst parameter in parameterAsts)
115126
{
116-
if (WrongCredentialUsage(parameter))
127+
if (WrongCredentialUsage(parameter, requiresTransformationAttribute))
117128
{
118129
yield return new DiagnosticRecord(
119130
errorMessage,
120-
parentAst.Extent,
131+
parameter.Extent,
121132
GetName(),
122133
DiagnosticSeverity.Warning,
123134
fileName);
124135
}
125136
}
126137
}
127138

128-
private bool WrongCredentialUsage(ParameterAst parameter)
139+
private bool WrongCredentialUsage(ParameterAst parameter, bool requiresTransformationAttribute)
129140
{
130141
if (parameter.Name.VariablePath.UserPath.Equals("Credential", StringComparison.OrdinalIgnoreCase))
131142
{
132143
var psCredentialType = parameter.Attributes.FirstOrDefault(paramAttribute => (paramAttribute.TypeName.IsArray && (paramAttribute.TypeName as ArrayTypeName).ElementType.GetReflectionType() == typeof(PSCredential))
133144
|| paramAttribute.TypeName.GetReflectionType() == typeof(PSCredential));
134145

146+
if (psCredentialType == null)
147+
{
148+
return true;
149+
}
150+
151+
if (!requiresTransformationAttribute && psCredentialType != null)
152+
{
153+
return false;
154+
}
155+
135156
var credentialAttribute = parameter.Attributes.FirstOrDefault(paramAttribute => paramAttribute.TypeName.GetReflectionType() == typeof(CredentialAttribute));
136157

137158
// check that both exists and pscredentialtype comes before credential attribute
138-
if (psCredentialType != null && credentialAttribute != null && psCredentialType.Extent.EndOffset <= credentialAttribute.Extent.StartOffset)
159+
if (psCredentialType != null
160+
&& credentialAttribute != null
161+
&& psCredentialType.Extent.EndOffset <= credentialAttribute.Extent.StartOffset)
139162
{
140163
return false;
141164
}

0 commit comments

Comments
 (0)