@@ -44,43 +44,51 @@ public IEnumerable<DiagnosticRecord> AnalyzeScript(Ast ast, string fileName)
44
44
if ( ast == null ) throw new ArgumentNullException ( Strings . NullAstErrorMessage ) ;
45
45
46
46
var sbAst = ast as ScriptBlockAst ;
47
+
48
+ var requiresTransformationAttribute = false ;
49
+ var psVersion = Helper . Instance . GetPSVersion ( ) ;
50
+ if ( psVersion != null && psVersion . Major < 5 )
51
+ {
52
+ requiresTransformationAttribute = true ;
53
+ }
54
+
55
+ // do not run the rule if the script requires PS version 5
56
+ // but PSSA in invoked through PS version < 5
47
57
if ( sbAst != null
48
- && sbAst . ScriptRequirements != null
49
- && sbAst . ScriptRequirements . RequiredPSVersion != null
50
- && sbAst . ScriptRequirements . RequiredPSVersion . Major == 5 )
58
+ && sbAst . ScriptRequirements != null
59
+ && sbAst . ScriptRequirements . RequiredPSVersion != null
60
+ && sbAst . ScriptRequirements . RequiredPSVersion . Major == 5
61
+ && requiresTransformationAttribute )
51
62
{
52
- yield break ;
63
+ yield break ;
53
64
}
54
65
55
66
IEnumerable < Ast > funcDefAsts = ast . FindAll ( testAst => testAst is FunctionDefinitionAst , true ) ;
56
67
IEnumerable < Ast > scriptBlockAsts = ast . FindAll ( testAst => testAst is ScriptBlockAst , true ) ;
57
68
58
- string funcName ;
69
+ List < DiagnosticRecord > diagnosticRecords = new List < DiagnosticRecord > ( ) ;
59
70
60
71
foreach ( FunctionDefinitionAst funcDefAst in funcDefAsts )
61
72
{
62
- funcName = funcDefAst . Name ;
63
-
73
+ IEnumerable < ParameterAst > parameterAsts = null ;
64
74
if ( funcDefAst . Parameters != null )
65
75
{
66
- foreach ( ParameterAst parameter in funcDefAst . Parameters )
67
- {
68
- if ( WrongCredentialUsage ( parameter ) )
69
- {
70
- yield return new DiagnosticRecord ( string . Format ( CultureInfo . CurrentCulture , Strings . UsePSCredentialTypeError , funcName ) , funcDefAst . Extent , GetName ( ) , DiagnosticSeverity . Warning , fileName ) ;
71
- }
72
- }
76
+ parameterAsts = funcDefAst . Parameters ;
73
77
}
74
78
75
- if ( funcDefAst . Body . ParamBlock != null )
79
+ if ( funcDefAst . Body . ParamBlock != null
80
+ && funcDefAst . Body . ParamBlock . Parameters != null )
76
81
{
77
- foreach ( ParameterAst parameter in funcDefAst . Body . ParamBlock . Parameters )
78
- {
79
- if ( WrongCredentialUsage ( parameter ) )
80
- {
81
- yield return new DiagnosticRecord ( string . Format ( CultureInfo . CurrentCulture , Strings . UsePSCredentialTypeError , funcName ) , funcDefAst . Extent , GetName ( ) , DiagnosticSeverity . Warning , fileName ) ;
82
- }
83
- }
82
+ parameterAsts = funcDefAst . Body . ParamBlock . Parameters ;
83
+ }
84
+
85
+ if ( parameterAsts != null )
86
+ {
87
+ diagnosticRecords . AddRange ( GetViolations (
88
+ parameterAsts ,
89
+ string . Format ( CultureInfo . CurrentCulture , Strings . UsePSCredentialTypeError , funcDefAst . Name ) ,
90
+ fileName ,
91
+ requiresTransformationAttribute ) ) ;
84
92
}
85
93
}
86
94
@@ -94,28 +102,68 @@ public IEnumerable<DiagnosticRecord> AnalyzeScript(Ast ast, string fileName)
94
102
95
103
if ( scriptBlockAst . ParamBlock != null && scriptBlockAst . ParamBlock . Parameters != null )
96
104
{
97
- foreach ( ParameterAst parameter in scriptBlockAst . ParamBlock . Parameters )
105
+ diagnosticRecords . AddRange ( GetViolations (
106
+ scriptBlockAst . ParamBlock . Parameters ,
107
+ string . Format ( CultureInfo . CurrentCulture , Strings . UsePSCredentialTypeErrorSB ) ,
108
+ fileName ,
109
+ requiresTransformationAttribute ) ) ;
110
+ }
111
+ }
112
+
113
+ foreach ( var dr in diagnosticRecords )
114
+ {
115
+ yield return dr ;
116
+ }
117
+ }
118
+
119
+ private IEnumerable < DiagnosticRecord > GetViolations (
120
+ IEnumerable < ParameterAst > parameterAsts ,
121
+ string errorMessage ,
122
+ string fileName ,
123
+ bool requiresTransformationAttribute )
124
+ {
125
+ foreach ( ParameterAst parameter in parameterAsts )
126
+ {
127
+ if ( WrongCredentialUsage ( parameter , requiresTransformationAttribute ) )
98
128
{
99
- if ( WrongCredentialUsage ( parameter ) )
100
- {
101
- yield return new DiagnosticRecord ( string . Format ( CultureInfo . CurrentCulture , Strings . UsePSCredentialTypeErrorSB ) , scriptBlockAst . Extent , GetName ( ) , DiagnosticSeverity . Warning , fileName ) ;
102
- }
129
+ yield return new DiagnosticRecord (
130
+ errorMessage ,
131
+ parameter . Extent ,
132
+ GetName ( ) ,
133
+ DiagnosticSeverity . Warning ,
134
+ fileName ) ;
103
135
}
104
136
}
105
- }
106
137
}
107
138
108
- private bool WrongCredentialUsage ( ParameterAst parameter )
139
+ private bool WrongCredentialUsage ( ParameterAst parameter , bool requiresTransformationAttribute )
109
140
{
110
141
if ( parameter . Name . VariablePath . UserPath . Equals ( "Credential" , StringComparison . OrdinalIgnoreCase ) )
111
142
{
112
143
var psCredentialType = parameter . Attributes . FirstOrDefault ( paramAttribute => ( paramAttribute . TypeName . IsArray && ( paramAttribute . TypeName as ArrayTypeName ) . ElementType . GetReflectionType ( ) == typeof ( PSCredential ) )
113
144
|| paramAttribute . TypeName . GetReflectionType ( ) == typeof ( PSCredential ) ) ;
114
145
115
- var credentialAttribute = parameter . Attributes . FirstOrDefault ( paramAttribute => paramAttribute . TypeName . GetReflectionType ( ) == typeof ( CredentialAttribute ) ) ;
146
+ if ( psCredentialType == null )
147
+ {
148
+ return true ;
149
+ }
150
+
151
+ if ( ! requiresTransformationAttribute && psCredentialType != null )
152
+ {
153
+ return false ;
154
+ }
155
+
156
+ var credentialAttribute = parameter . Attributes . FirstOrDefault (
157
+ paramAttribute =>
158
+ paramAttribute . TypeName . GetReflectionType ( ) == typeof ( CredentialAttribute )
159
+ || paramAttribute . TypeName . FullName . Equals (
160
+ "System.Management.Automation.Credential" ,
161
+ StringComparison . OrdinalIgnoreCase ) ) ;
116
162
117
163
// check that both exists and pscredentialtype comes before credential attribute
118
- if ( psCredentialType != null && credentialAttribute != null && psCredentialType . Extent . EndOffset <= credentialAttribute . Extent . StartOffset )
164
+ if ( psCredentialType != null
165
+ && credentialAttribute != null
166
+ && psCredentialType . Extent . EndOffset <= credentialAttribute . Extent . StartOffset )
119
167
{
120
168
return false ;
121
169
}
0 commit comments