@@ -12,7 +12,7 @@ import (
1212 "github.com/stretchr/testify/assert"
1313)
1414
15- func TestRenderInternal (t * testing.T ) {
15+ func TestRenderInternalAttrs (t * testing.T ) {
1616 cases := []struct {
1717 input , protected , recovered string
1818 }{
@@ -30,7 +30,7 @@ func TestRenderInternal(t *testing.T) {
3030 for _ , c := range cases {
3131 var r RenderInternal
3232 out := & bytes.Buffer {}
33- in := r .init ("sec" , out )
33+ in := r .init ("sec" , out , "" )
3434 protected := r .ProtectSafeAttrs (template .HTML (c .input ))
3535 assert .EqualValues (t , c .protected , protected )
3636 _ , _ = io .WriteString (in , string (protected ))
@@ -41,7 +41,7 @@ func TestRenderInternal(t *testing.T) {
4141 var r1 , r2 RenderInternal
4242 protected := r1 .ProtectSafeAttrs (`<div class="test"></div>` )
4343 assert .EqualValues (t , `<div class="test"></div>` , protected , "non-initialized RenderInternal should not protect any attributes" )
44- _ = r1 .init ("sec" , nil )
44+ _ = r1 .init ("sec" , nil , "" )
4545 protected = r1 .ProtectSafeAttrs (`<div class="test"></div>` )
4646 assert .EqualValues (t , `<div data-attr-class="sec:test"></div>` , protected )
4747 assert .Equal (t , "data-attr-class" , r1 .SafeAttr ("class" ))
@@ -54,8 +54,37 @@ func TestRenderInternal(t *testing.T) {
5454 assert .Empty (t , recovered )
5555
5656 out2 := & bytes.Buffer {}
57- in2 := r2 .init ("sec-other" , out2 )
57+ in2 := r2 .init ("sec-other" , out2 , "" )
5858 _ , _ = io .WriteString (in2 , string (protected ))
5959 _ = in2 .Close ()
6060 assert .Equal (t , `<div data-attr-class="sec:test"></div>` , out2 .String (), "different secureID should not recover the value" )
6161}
62+
63+ func TestRenderInternalExtraHead (t * testing.T ) {
64+ t .Run ("HeadExists" , func (t * testing.T ) {
65+ out := & bytes.Buffer {}
66+ var r RenderInternal
67+ in := r .init ("sec" , out , `<MY-TAG>` )
68+ _ , _ = io .WriteString (in , `<head>any</head>` )
69+ _ = in .Close ()
70+ assert .Equal (t , `<head><MY-TAG>any</head>` , out .String ())
71+ })
72+
73+ t .Run ("HeadNotExists" , func (t * testing.T ) {
74+ out := & bytes.Buffer {}
75+ var r RenderInternal
76+ in := r .init ("sec" , out , `<MY-TAG>` )
77+ _ , _ = io .WriteString (in , `<div></div>` )
78+ _ = in .Close ()
79+ assert .Equal (t , `<MY-TAG><div></div>` , out .String ())
80+ })
81+
82+ t .Run ("NotHTML" , func (t * testing.T ) {
83+ out := & bytes.Buffer {}
84+ var r RenderInternal
85+ in := r .init ("sec" , out , `<MY-TAG>` )
86+ _ , _ = io .WriteString (in , `<any>` )
87+ _ = in .Close ()
88+ assert .Equal (t , `<any>` , out .String ())
89+ })
90+ }
0 commit comments